Getting Started

Prerequisites

Here we assume that you are familiar with Lux.jl. If not, please take a look at the Lux.jl tutorials.

Boltz.jl is just like Lux.jl but comes with more "batteries included". Let's start by defining an MLP model.

julia
using Lux, Boltz, Random

Multi-Layer Perceptron

If we were to do this in Lux.jl, we would write the following:

julia
model = Chain(Dense(784, 256, relu), Dense(256, 10))
Chain(
    layer_1 = Dense(784 => 256, relu),  # 200_960 parameters
    layer_2 = Dense(256 => 10),         # 2_570 parameters
)         # Total: 203_530 parameters,
          #        plus 0 states.

But in Boltz.jl we can do this:

julia
model = Layers.MLP(784, (256, 10), relu)
MLP(
    chain = Chain(
        block1 = DenseNormActDropoutBlock(
            block = Chain(
                dense = Dense(784 => 256, relu),  # 200_960 parameters
            ),
        ),
        block2 = DenseNormActDropoutBlock(
            block = Chain(
                dense = Dense(256 => 10),  # 2_570 parameters
            ),
        ),
    ),
)         # Total: 203_530 parameters,
          #        plus 0 states.

Layers.MLP is just a convenience wrapper around Lux.Chain that constructs a multi-layer perceptron from the given input dimension, layer sizes, and activation function.
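
Since Layers.MLP returns a regular Lux model, it can be initialized and called like any other Lux layer. Here is a minimal sketch of a forward pass; the rng, input batch, and variable names below are ours, not part of the model definition above:

julia
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)     # initialize parameters and states
x = randn(rng, Float32, 784, 4)    # dummy batch of 4 samples
y, st = model(x, ps, st)           # forward pass
size(y)                            # (10, 4)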

How about VGG?

Let's take a look at the Vision module. We can construct a VGG model with the following code:

julia
Vision.VGG(13)
VGG(
    layer = Chain(
        feature_extractor = VGGFeatureExtractor(
            model = Chain(
                layer_1 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 3 => 64, relu, pad=1),  # 1_792 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 64 => 64, relu, pad=1),  # 36_928 parameters
                        ),
                    ),
                ),
                layer_2 = MaxPool((2, 2)),
                layer_3 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 64 => 128, relu, pad=1),  # 73_856 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 128 => 128, relu, pad=1),  # 147_584 parameters
                        ),
                    ),
                ),
                layer_4 = MaxPool((2, 2)),
                layer_5 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 128 => 256, relu, pad=1),  # 295_168 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 256 => 256, relu, pad=1),  # 590_080 parameters
                        ),
                    ),
                ),
                layer_6 = MaxPool((2, 2)),
                layer_7 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 256 => 512, relu, pad=1),  # 1_180_160 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                    ),
                ),
                layer_8 = MaxPool((2, 2)),
                layer_9 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                    ),
                ),
                layer_10 = MaxPool((2, 2)),
            ),
        ),
        classifier = VGGClassifier(
            model = Chain(
                layer_1 = Lux.FlattenLayer{Nothing}(nothing),
                layer_2 = Dense(25088 => 4096, relu),  # 102_764_544 parameters
                layer_3 = Dropout(0.5),
                layer_4 = Dense(4096 => 4096, relu),  # 16_781_312 parameters
                layer_5 = Dropout(0.5),
                layer_6 = Dense(4096 => 1000),  # 4_097_000 parameters
            ),
        ),
    ),
)         # Total: 133_047_848 parameters,
          #        plus 4 states.
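
The parameter and state counts printed above can also be queried programmatically. A small sketch using Lux's introspection helpers (the vgg variable name is ours; the counts are the ones shown in the output above):

julia
vgg = Vision.VGG(13)
Lux.parameterlength(vgg)   # 133_047_848 trainable parameters
Lux.statelength(vgg)       # 4 state entries (from the two Dropout layers)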

We can also load pretrained ImageNet weights using the pretrained keyword argument.

Load JLD2

You need to load JLD2 before you can load pretrained weights.

julia
using JLD2

Vision.VGG(13; pretrained=true)
VGG(
    layer = Chain(
        feature_extractor = VGGFeatureExtractor(
            model = Chain(
                layer_1 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 3 => 64, relu, pad=1),  # 1_792 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 64 => 64, relu, pad=1),  # 36_928 parameters
                        ),
                    ),
                ),
                layer_2 = MaxPool((2, 2)),
                layer_3 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 64 => 128, relu, pad=1),  # 73_856 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 128 => 128, relu, pad=1),  # 147_584 parameters
                        ),
                    ),
                ),
                layer_4 = MaxPool((2, 2)),
                layer_5 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 128 => 256, relu, pad=1),  # 295_168 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 256 => 256, relu, pad=1),  # 590_080 parameters
                        ),
                    ),
                ),
                layer_6 = MaxPool((2, 2)),
                layer_7 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 256 => 512, relu, pad=1),  # 1_180_160 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                    ),
                ),
                layer_8 = MaxPool((2, 2)),
                layer_9 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                    ),
                ),
                layer_10 = MaxPool((2, 2)),
            ),
        ),
        classifier = VGGClassifier(
            model = Chain(
                layer_1 = Lux.FlattenLayer{Nothing}(nothing),
                layer_2 = Dense(25088 => 4096, relu),  # 102_764_544 parameters
                layer_3 = Dropout(0.5),
                layer_4 = Dense(4096 => 4096, relu),  # 16_781_312 parameters
                layer_5 = Dropout(0.5),
                layer_6 = Dense(4096 => 1000),  # 4_097_000 parameters
            ),
        ),
    ),
)         # Total: 133_047_848 parameters,
          #        plus 4 states.
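
A typical inference sketch with the pretrained model, assuming a standard 224×224 RGB ImageNet-sized input and that the pretrained weights are materialized when the parameters are initialized via Lux.setup (the vgg13 and other variable names below are ours):

julia
vgg13 = Vision.VGG(13; pretrained=true)
ps, st = Lux.setup(Random.default_rng(), vgg13)
st = Lux.testmode(st)               # disable dropout for inference
x = rand(Float32, 224, 224, 3, 1)   # dummy 224×224 RGB image batch (assumed input size)
y, _ = vgg13(x, ps, st)
size(y)                             # (1000, 1) class logits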

Loading Models from Metalhead (Flux.jl)

We can load models from Metalhead (Flux.jl); just remember to load Metalhead first.

julia
using Metalhead

Vision.ResNet(18)
MetalheadWrapperLayer(
    layer = Chain(
        layer_1 = Chain(
            layer_1 = Chain(
                layer_1 = Conv((7, 7), 3 => 64, pad=3, stride=2, use_bias=false),  # 9_408 parameters
                layer_2 = BatchNorm(64, relu, affine=true, track_stats=true),  # 128 parameters, plus 129
                layer_3 = MaxPool((3, 3), pad=1, stride=2),
            ),
            layer_2 = Chain(
                layer_1 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Lux.NoOpLayer(),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 64 => 64, pad=1, use_bias=false),  # 36_864 parameters
                        layer_2 = BatchNorm(64, affine=true, track_stats=true),  # 128 parameters, plus 129
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 64 => 64, pad=1, use_bias=false),  # 36_864 parameters
                        layer_5 = BatchNorm(64, affine=true, track_stats=true),  # 128 parameters, plus 129
                    ),
                ),
                layer_2 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Lux.NoOpLayer(),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 64 => 64, pad=1, use_bias=false),  # 36_864 parameters
                        layer_2 = BatchNorm(64, affine=true, track_stats=true),  # 128 parameters, plus 129
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 64 => 64, pad=1, use_bias=false),  # 36_864 parameters
                        layer_5 = BatchNorm(64, affine=true, track_stats=true),  # 128 parameters, plus 129
                    ),
                ),
            ),
            layer_3 = Chain(
                layer_1 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Chain(
                        layer_1 = Conv((1, 1), 64 => 128, stride=2, use_bias=false),  # 8_192 parameters
                        layer_2 = BatchNorm(128, affine=true, track_stats=true),  # 256 parameters, plus 257
                    ),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 64 => 128, pad=1, stride=2, use_bias=false),  # 73_728 parameters
                        layer_2 = BatchNorm(128, affine=true, track_stats=true),  # 256 parameters, plus 257
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 128 => 128, pad=1, use_bias=false),  # 147_456 parameters
                        layer_5 = BatchNorm(128, affine=true, track_stats=true),  # 256 parameters, plus 257
                    ),
                ),
                layer_2 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Lux.NoOpLayer(),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 128 => 128, pad=1, use_bias=false),  # 147_456 parameters
                        layer_2 = BatchNorm(128, affine=true, track_stats=true),  # 256 parameters, plus 257
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 128 => 128, pad=1, use_bias=false),  # 147_456 parameters
                        layer_5 = BatchNorm(128, affine=true, track_stats=true),  # 256 parameters, plus 257
                    ),
                ),
            ),
            layer_4 = Chain(
                layer_1 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Chain(
                        layer_1 = Conv((1, 1), 128 => 256, stride=2, use_bias=false),  # 32_768 parameters
                        layer_2 = BatchNorm(256, affine=true, track_stats=true),  # 512 parameters, plus 513
                    ),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 128 => 256, pad=1, stride=2, use_bias=false),  # 294_912 parameters
                        layer_2 = BatchNorm(256, affine=true, track_stats=true),  # 512 parameters, plus 513
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 256 => 256, pad=1, use_bias=false),  # 589_824 parameters
                        layer_5 = BatchNorm(256, affine=true, track_stats=true),  # 512 parameters, plus 513
                    ),
                ),
                layer_2 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Lux.NoOpLayer(),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 256 => 256, pad=1, use_bias=false),  # 589_824 parameters
                        layer_2 = BatchNorm(256, affine=true, track_stats=true),  # 512 parameters, plus 513
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 256 => 256, pad=1, use_bias=false),  # 589_824 parameters
                        layer_5 = BatchNorm(256, affine=true, track_stats=true),  # 512 parameters, plus 513
                    ),
                ),
            ),
            layer_5 = Chain(
                layer_1 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Chain(
                        layer_1 = Conv((1, 1), 256 => 512, stride=2, use_bias=false),  # 131_072 parameters
                        layer_2 = BatchNorm(512, affine=true, track_stats=true),  # 1_024 parameters, plus 1_025
                    ),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 256 => 512, pad=1, stride=2, use_bias=false),  # 1_179_648 parameters
                        layer_2 = BatchNorm(512, affine=true, track_stats=true),  # 1_024 parameters, plus 1_025
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 512 => 512, pad=1, use_bias=false),  # 2_359_296 parameters
                        layer_5 = BatchNorm(512, affine=true, track_stats=true),  # 1_024 parameters, plus 1_025
                    ),
                ),
                layer_2 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Lux.NoOpLayer(),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 512 => 512, pad=1, use_bias=false),  # 2_359_296 parameters
                        layer_2 = BatchNorm(512, affine=true, track_stats=true),  # 1_024 parameters, plus 1_025
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 512 => 512, pad=1, use_bias=false),  # 2_359_296 parameters
                        layer_5 = BatchNorm(512, affine=true, track_stats=true),  # 1_024 parameters, plus 1_025
                    ),
                ),
            ),
        ),
        layer_2 = Chain(
            layer_1 = AdaptiveMeanPool((1, 1)),
            layer_2 = WrappedFunction(flatten),
            layer_3 = Dense(512 => 1000),  # 513_000 parameters
        ),
    ),
)         # Total: 11_689_512 parameters,
          #        plus 9_620 states.

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
  JULIA_NUM_THREADS = 1
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.