Getting Started
Here we assume that you are familiar with Lux.jl. If not, please take a look at the Lux.jl tutorials.
Boltz.jl is just like Lux.jl but comes with more "batteries included". Let's start by defining an MLP model.
using Lux, Boltz, Random
Multi-Layer Perceptron
If we were to do this in Lux.jl, we would write the following:
# Plain Lux: a 784 -> 256 (relu) -> 10 perceptron, written with the `in => out` Pair form.
model = Chain(Dense(784 => 256, relu), Dense(256 => 10))
layer_1 = Dense(784 => 256, relu), # 200_960 parameters
layer_2 = Dense(256 => 10), # 2_570 parameters
) # Total: 203_530 parameters,
# plus 0 states.
But in Boltz.jl, we can do this:
# Boltz: input size 784, layer widths (256, 10), activation `relu`.
# As the printed summary shows, `relu` ends up on the 784 -> 256 layer and the final 256 -> 10 layer has none.
model = Layers.MLP(784, (256, 10), relu)
chain = Chain(
block1 = DenseNormActDropoutBlock(
block = Chain(
dense = Dense(784 => 256, relu), # 200_960 parameters
block2 = DenseNormActDropoutBlock(
block = Chain(
dense = Dense(256 => 10), # 2_570 parameters
) # Total: 203_530 parameters,
# plus 0 states.
The `Layers.MLP` function is just a convenience wrapper around `Lux.Chain` that constructs a multi-layer perceptron with the given number of layers and activation function.
How about VGG?
Let's take a look at the `Vision` module. We can construct a VGG model (here VGG-13) with `Vision.VGG(13)`:
layer = Chain(
feature_extractor = VGGFeatureExtractor(
model = Chain(
layer_1 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 3 => 64, relu, pad=1), # 1_792 parameters
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 64 => 64, relu, pad=1), # 36_928 parameters
layer_2 = MaxPool((2, 2)),
layer_3 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 64 => 128, relu, pad=1), # 73_856 parameters
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 128 => 128, relu, pad=1), # 147_584 parameters
layer_4 = MaxPool((2, 2)),
layer_5 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 128 => 256, relu, pad=1), # 295_168 parameters
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 256 => 256, relu, pad=1), # 590_080 parameters
layer_6 = MaxPool((2, 2)),
layer_7 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 256 => 512, relu, pad=1), # 1_180_160 parameters
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
layer_8 = MaxPool((2, 2)),
layer_9 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
layer_10 = MaxPool((2, 2)),
classifier = VGGClassifier(
model = Chain(
layer_1 = Lux.FlattenLayer{Nothing}(nothing),
layer_2 = Dense(25088 => 4096, relu), # 102_764_544 parameters
layer_3 = Dropout(0.5),
layer_4 = Dense(4096 => 4096, relu), # 16_781_312 parameters
layer_5 = Dropout(0.5),
layer_6 = Dense(4096 => 1000), # 4_097_000 parameters
) # Total: 133_047_848 parameters,
# plus 4 states.
We can also load pretrained ImageNet weights by passing `pretrained=true`:
Load JLD2
You need to load JLD2 before being able to load pretrained weights.
using JLD2
# VGG-13 with pretrained ImageNet weights; `using JLD2` must already have been run.
Vision.VGG(13; pretrained=true)
layer = Chain(
feature_extractor = VGGFeatureExtractor(
model = Chain(
layer_1 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 3 => 64, relu, pad=1), # 1_792 parameters
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 64 => 64, relu, pad=1), # 36_928 parameters
layer_2 = MaxPool((2, 2)),
layer_3 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 64 => 128, relu, pad=1), # 73_856 parameters
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 128 => 128, relu, pad=1), # 147_584 parameters
layer_4 = MaxPool((2, 2)),
layer_5 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 128 => 256, relu, pad=1), # 295_168 parameters
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 256 => 256, relu, pad=1), # 590_080 parameters
layer_6 = MaxPool((2, 2)),
layer_7 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 256 => 512, relu, pad=1), # 1_180_160 parameters
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
layer_8 = MaxPool((2, 2)),
layer_9 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
layer_10 = MaxPool((2, 2)),
classifier = VGGClassifier(
model = Chain(
layer_1 = Lux.FlattenLayer{Nothing}(nothing),
layer_2 = Dense(25088 => 4096, relu), # 102_764_544 parameters
layer_3 = Dropout(0.5),
layer_4 = Dense(4096 => 4096, relu), # 16_781_312 parameters
layer_5 = Dropout(0.5),
layer_6 = Dense(4096 => 1000), # 4_097_000 parameters
) # Total: 133_047_848 parameters,
# plus 4 states.
Loading Models from Metalhead (Flux.jl)
We can load models from Metalhead (Flux.jl); just remember to load Metalhead first:
using Metalhead
layer = Chain(
layer_1 = Chain(
layer_1 = Chain(
layer_1 = Conv((7, 7), 3 => 64, pad=3, stride=2, use_bias=false), # 9_408 parameters
layer_2 = BatchNorm(64, relu, affine=true, track_stats=true), # 128 parameters, plus 129
layer_3 = MaxPool((3, 3), pad=1, stride=2),
layer_2 = Chain(
layer_1 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Lux.NoOpLayer(),
layer_2 = Chain(
layer_1 = Conv((3, 3), 64 => 64, pad=1, use_bias=false), # 36_864 parameters
layer_2 = BatchNorm(64, affine=true, track_stats=true), # 128 parameters, plus 129
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 64 => 64, pad=1, use_bias=false), # 36_864 parameters
layer_5 = BatchNorm(64, affine=true, track_stats=true), # 128 parameters, plus 129
layer_2 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Lux.NoOpLayer(),
layer_2 = Chain(
layer_1 = Conv((3, 3), 64 => 64, pad=1, use_bias=false), # 36_864 parameters
layer_2 = BatchNorm(64, affine=true, track_stats=true), # 128 parameters, plus 129
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 64 => 64, pad=1, use_bias=false), # 36_864 parameters
layer_5 = BatchNorm(64, affine=true, track_stats=true), # 128 parameters, plus 129
layer_3 = Chain(
layer_1 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Chain(
layer_1 = Conv((1, 1), 64 => 128, stride=2, use_bias=false), # 8_192 parameters
layer_2 = BatchNorm(128, affine=true, track_stats=true), # 256 parameters, plus 257
layer_2 = Chain(
layer_1 = Conv((3, 3), 64 => 128, pad=1, stride=2, use_bias=false), # 73_728 parameters
layer_2 = BatchNorm(128, affine=true, track_stats=true), # 256 parameters, plus 257
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 128 => 128, pad=1, use_bias=false), # 147_456 parameters
layer_5 = BatchNorm(128, affine=true, track_stats=true), # 256 parameters, plus 257
layer_2 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Lux.NoOpLayer(),
layer_2 = Chain(
layer_1 = Conv((3, 3), 128 => 128, pad=1, use_bias=false), # 147_456 parameters
layer_2 = BatchNorm(128, affine=true, track_stats=true), # 256 parameters, plus 257
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 128 => 128, pad=1, use_bias=false), # 147_456 parameters
layer_5 = BatchNorm(128, affine=true, track_stats=true), # 256 parameters, plus 257
layer_4 = Chain(
layer_1 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Chain(
layer_1 = Conv((1, 1), 128 => 256, stride=2, use_bias=false), # 32_768 parameters
layer_2 = BatchNorm(256, affine=true, track_stats=true), # 512 parameters, plus 513
layer_2 = Chain(
layer_1 = Conv((3, 3), 128 => 256, pad=1, stride=2, use_bias=false), # 294_912 parameters
layer_2 = BatchNorm(256, affine=true, track_stats=true), # 512 parameters, plus 513
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 256 => 256, pad=1, use_bias=false), # 589_824 parameters
layer_5 = BatchNorm(256, affine=true, track_stats=true), # 512 parameters, plus 513
layer_2 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Lux.NoOpLayer(),
layer_2 = Chain(
layer_1 = Conv((3, 3), 256 => 256, pad=1, use_bias=false), # 589_824 parameters
layer_2 = BatchNorm(256, affine=true, track_stats=true), # 512 parameters, plus 513
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 256 => 256, pad=1, use_bias=false), # 589_824 parameters
layer_5 = BatchNorm(256, affine=true, track_stats=true), # 512 parameters, plus 513
layer_5 = Chain(
layer_1 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Chain(
layer_1 = Conv((1, 1), 256 => 512, stride=2, use_bias=false), # 131_072 parameters
layer_2 = BatchNorm(512, affine=true, track_stats=true), # 1_024 parameters, plus 1_025
layer_2 = Chain(
layer_1 = Conv((3, 3), 256 => 512, pad=1, stride=2, use_bias=false), # 1_179_648 parameters
layer_2 = BatchNorm(512, affine=true, track_stats=true), # 1_024 parameters, plus 1_025
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 512 => 512, pad=1, use_bias=false), # 2_359_296 parameters
layer_5 = BatchNorm(512, affine=true, track_stats=true), # 1_024 parameters, plus 1_025
layer_2 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Lux.NoOpLayer(),
layer_2 = Chain(
layer_1 = Conv((3, 3), 512 => 512, pad=1, use_bias=false), # 2_359_296 parameters
layer_2 = BatchNorm(512, affine=true, track_stats=true), # 1_024 parameters, plus 1_025
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 512 => 512, pad=1, use_bias=false), # 2_359_296 parameters
layer_5 = BatchNorm(512, affine=true, track_stats=true), # 1_024 parameters, plus 1_025
layer_2 = Chain(
layer_1 = AdaptiveMeanPool((1, 1)),
layer_2 = WrappedFunction(flatten),
layer_3 = Dense(512 => 1000), # 513_000 parameters
) # Total: 11_689_512 parameters,
# plus 9_620 states.
using InteractiveUtils
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
JULIA_DEBUG = Literate
This page was generated using Literate.jl.