Getting Started
Prerequisites
Here we assume that you are familiar with Lux.jl. If not, please take a look at the Lux.jl tutorials. Boltz.jl is just like Lux.jl but comes with more "batteries included". Let's start by defining an MLP model.
using Lux, Boltz, Random
Multi-Layer Perceptron
If we were to do this in Lux.jl, we would write the following:
model = Chain(
Dense(784, 256, relu),
Dense(256, 10)
)
Chain(
layer_1 = Dense(784 => 256, relu), # 200_960 parameters
layer_2 = Dense(256 => 10), # 2_570 parameters
) # Total: 203_530 parameters,
# plus 0 states.
But in Boltz.jl we can do this:
model = Layers.MLP(784, (256, 10), relu)
MLP(
chain = Chain(
block1 = DenseNormActDropoutBlock(
block = Chain(
dense = Dense(784 => 256, relu), # 200_960 parameters
),
),
block2 = DenseNormActDropoutBlock(
block = Chain(
dense = Dense(256 => 10), # 2_570 parameters
),
),
),
) # Total: 203_530 parameters,
# plus 0 states.
Layers.MLP is just a convenience wrapper around Lux.Chain that constructs a multi-layer perceptron from the input size, a tuple of layer output sizes, and an activation function. As the printout shows, the activation is applied to every layer except the last.
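Because the result is an ordinary Lux layer, it can be initialized and called like any other Lux model. The following is a minimal sketch, assuming the standard Lux workflow of Lux.setup followed by a call with parameters and states; the batch size of 4 and the random input are illustrative choices, not part of the original tutorial.

```julia
using Lux, Boltz, Random

# Seeded RNG so that parameter initialization is reproducible.
rng = Random.Xoshiro(0)

# Same model as above: 784 inputs, hidden width 256, 10 outputs.
model = Layers.MLP(784, (256, 10), relu)

# Lux keeps parameters and states outside the model object.
ps, st = Lux.setup(rng, model)

# Illustrative input: a batch of 4 flattened 28x28 images (features x batch).
x = randn(rng, Float32, 784, 4)

# Calling the model returns the output and the (possibly updated) state.
y, st_new = model(x, ps, st)

println(size(y))
```

The output has one column per input sample and one row per output unit of the final Dense(256 => 10) layer.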
How about VGG?
Let's take a look at the Vision module. We can construct a VGG model with the following code:
Vision.VGG(13)
VGG(
layer = Chain(
feature_extractor = VGGFeatureExtractor(
model = Chain(
layer_1 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 3 => 64, relu, pad=1), # 1_792 parameters
),
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 64 => 64, relu, pad=1), # 36_928 parameters
),
),
),
layer_2 = MaxPool((2, 2)),
layer_3 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 64 => 128, relu, pad=1), # 73_856 parameters
),
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 128 => 128, relu, pad=1), # 147_584 parameters
),
),
),
layer_4 = MaxPool((2, 2)),
layer_5 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 128 => 256, relu, pad=1), # 295_168 parameters
),
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 256 => 256, relu, pad=1), # 590_080 parameters
),
),
),
layer_6 = MaxPool((2, 2)),
layer_7 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 256 => 512, relu, pad=1), # 1_180_160 parameters
),
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
),
),
),
layer_8 = MaxPool((2, 2)),
layer_9 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
),
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
),
),
),
layer_10 = MaxPool((2, 2)),
),
),
classifier = VGGClassifier(
model = Chain(
layer_1 = Lux.FlattenLayer{Nothing}(nothing),
layer_2 = Dense(25088 => 4096, relu), # 102_764_544 parameters
layer_3 = Dropout(0.5),
layer_4 = Dense(4096 => 4096, relu), # 16_781_312 parameters
layer_5 = Dropout(0.5),
layer_6 = Dense(4096 => 1000), # 4_097_000 parameters
),
),
),
) # Total: 133_047_848 parameters,
# plus 4 states.
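The parameter counts in this printout can be checked by hand: a convolution with a k×k kernel has k·k·c_in·c_out weights plus c_out biases, and a dense layer has n_in·n_out weights plus n_out biases. The sketch below redoes that arithmetic in plain Julia; the helper names conv_params and dense_params are hypothetical and exist only for this illustration.

```julia
# Hypothetical helpers for illustration only; they are not part of Lux or Boltz.
conv_params(k, cin, cout) = k * k * cin * cout + cout   # weights + biases
dense_params(nin, nout) = nin * nout + nout             # weights + biases

# Channel pairs of the ten 3x3 convolutions, read off the printout above.
channels = [3 => 64, 64 => 64, 64 => 128, 128 => 128, 128 => 256,
            256 => 256, 256 => 512, 512 => 512, 512 => 512, 512 => 512]

conv_total = sum(conv_params(3, cin, cout) for (cin, cout) in channels)

# The three dense layers of the classifier.
dense_total = dense_params(25088, 4096) +
              dense_params(4096, 4096) +
              dense_params(4096, 1000)

total = conv_total + dense_total
println(total)  # 133047848, matching the total in the printout
```

The 25088 inputs of the first dense layer correspond to 512 channels on a 7×7 feature map, which is what five 2×2 max-pooling steps leave of a 224×224 input.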
We can also load pretrained ImageNet weights by passing pretrained=true.
Load JLD2
You need to load JLD2 before being able to load pretrained weights.
using JLD2
Vision.VGG(13; pretrained=true)
VGG(
layer = Chain(
feature_extractor = VGGFeatureExtractor(
model = Chain(
layer_1 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 3 => 64, relu, pad=1), # 1_792 parameters
),
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 64 => 64, relu, pad=1), # 36_928 parameters
),
),
),
layer_2 = MaxPool((2, 2)),
layer_3 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 64 => 128, relu, pad=1), # 73_856 parameters
),
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 128 => 128, relu, pad=1), # 147_584 parameters
),
),
),
layer_4 = MaxPool((2, 2)),
layer_5 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 128 => 256, relu, pad=1), # 295_168 parameters
),
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 256 => 256, relu, pad=1), # 590_080 parameters
),
),
),
layer_6 = MaxPool((2, 2)),
layer_7 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 256 => 512, relu, pad=1), # 1_180_160 parameters
),
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
),
),
),
layer_8 = MaxPool((2, 2)),
layer_9 = ConvNormActivation(
model = Chain(
block1 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
),
block2 = ConvNormActivationBlock(
block = Conv((3, 3), 512 => 512, relu, pad=1), # 2_359_808 parameters
),
),
),
layer_10 = MaxPool((2, 2)),
),
),
classifier = VGGClassifier(
model = Chain(
layer_1 = Lux.FlattenLayer{Nothing}(nothing),
layer_2 = Dense(25088 => 4096, relu), # 102_764_544 parameters
layer_3 = Dropout(0.5),
layer_4 = Dense(4096 => 4096, relu), # 16_781_312 parameters
layer_5 = Dropout(0.5),
layer_6 = Dense(4096 => 1000), # 4_097_000 parameters
),
),
),
) # Total: 133_047_848 parameters,
# plus 4 states.
Loading Models from Metalhead (Flux.jl)
We can load models from Metalhead (Flux.jl); just remember to load Metalhead first.
using Metalhead
Vision.ResNet(18)
MetalheadWrapperLayer(
layer = Chain(
layer_1 = Chain(
layer_1 = Chain(
layer_1 = Conv((7, 7), 3 => 64, pad=3, stride=2, use_bias=false), # 9_408 parameters
layer_2 = BatchNorm(64, relu, affine=true, track_stats=true), # 128 parameters, plus 129
layer_3 = MaxPool((3, 3), pad=1, stride=2),
),
layer_2 = Chain(
layer_1 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Lux.NoOpLayer(),
layer_2 = Chain(
layer_1 = Conv((3, 3), 64 => 64, pad=1, use_bias=false), # 36_864 parameters
layer_2 = BatchNorm(64, affine=true, track_stats=true), # 128 parameters, plus 129
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 64 => 64, pad=1, use_bias=false), # 36_864 parameters
layer_5 = BatchNorm(64, affine=true, track_stats=true), # 128 parameters, plus 129
),
),
layer_2 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Lux.NoOpLayer(),
layer_2 = Chain(
layer_1 = Conv((3, 3), 64 => 64, pad=1, use_bias=false), # 36_864 parameters
layer_2 = BatchNorm(64, affine=true, track_stats=true), # 128 parameters, plus 129
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 64 => 64, pad=1, use_bias=false), # 36_864 parameters
layer_5 = BatchNorm(64, affine=true, track_stats=true), # 128 parameters, plus 129
),
),
),
layer_3 = Chain(
layer_1 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Chain(
layer_1 = Conv((1, 1), 64 => 128, stride=2, use_bias=false), # 8_192 parameters
layer_2 = BatchNorm(128, affine=true, track_stats=true), # 256 parameters, plus 257
),
layer_2 = Chain(
layer_1 = Conv((3, 3), 64 => 128, pad=1, stride=2, use_bias=false), # 73_728 parameters
layer_2 = BatchNorm(128, affine=true, track_stats=true), # 256 parameters, plus 257
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 128 => 128, pad=1, use_bias=false), # 147_456 parameters
layer_5 = BatchNorm(128, affine=true, track_stats=true), # 256 parameters, plus 257
),
),
layer_2 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Lux.NoOpLayer(),
layer_2 = Chain(
layer_1 = Conv((3, 3), 128 => 128, pad=1, use_bias=false), # 147_456 parameters
layer_2 = BatchNorm(128, affine=true, track_stats=true), # 256 parameters, plus 257
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 128 => 128, pad=1, use_bias=false), # 147_456 parameters
layer_5 = BatchNorm(128, affine=true, track_stats=true), # 256 parameters, plus 257
),
),
),
layer_4 = Chain(
layer_1 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Chain(
layer_1 = Conv((1, 1), 128 => 256, stride=2, use_bias=false), # 32_768 parameters
layer_2 = BatchNorm(256, affine=true, track_stats=true), # 512 parameters, plus 513
),
layer_2 = Chain(
layer_1 = Conv((3, 3), 128 => 256, pad=1, stride=2, use_bias=false), # 294_912 parameters
layer_2 = BatchNorm(256, affine=true, track_stats=true), # 512 parameters, plus 513
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 256 => 256, pad=1, use_bias=false), # 589_824 parameters
layer_5 = BatchNorm(256, affine=true, track_stats=true), # 512 parameters, plus 513
),
),
layer_2 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Lux.NoOpLayer(),
layer_2 = Chain(
layer_1 = Conv((3, 3), 256 => 256, pad=1, use_bias=false), # 589_824 parameters
layer_2 = BatchNorm(256, affine=true, track_stats=true), # 512 parameters, plus 513
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 256 => 256, pad=1, use_bias=false), # 589_824 parameters
layer_5 = BatchNorm(256, affine=true, track_stats=true), # 512 parameters, plus 513
),
),
),
layer_5 = Chain(
layer_1 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Chain(
layer_1 = Conv((1, 1), 256 => 512, stride=2, use_bias=false), # 131_072 parameters
layer_2 = BatchNorm(512, affine=true, track_stats=true), # 1_024 parameters, plus 1_025
),
layer_2 = Chain(
layer_1 = Conv((3, 3), 256 => 512, pad=1, stride=2, use_bias=false), # 1_179_648 parameters
layer_2 = BatchNorm(512, affine=true, track_stats=true), # 1_024 parameters, plus 1_025
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 512 => 512, pad=1, use_bias=false), # 2_359_296 parameters
layer_5 = BatchNorm(512, affine=true, track_stats=true), # 1_024 parameters, plus 1_025
),
),
layer_2 = Parallel(
connection = addact(NNlib.relu, ...),
layer_1 = Lux.NoOpLayer(),
layer_2 = Chain(
layer_1 = Conv((3, 3), 512 => 512, pad=1, use_bias=false), # 2_359_296 parameters
layer_2 = BatchNorm(512, affine=true, track_stats=true), # 1_024 parameters, plus 1_025
layer_3 = WrappedFunction(relu),
layer_4 = Conv((3, 3), 512 => 512, pad=1, use_bias=false), # 2_359_296 parameters
layer_5 = BatchNorm(512, affine=true, track_stats=true), # 1_024 parameters, plus 1_025
),
),
),
),
layer_2 = Chain(
layer_1 = AdaptiveMeanPool((1, 1)),
layer_2 = WrappedFunction(flatten),
layer_3 = Dense(512 => 1000), # 513_000 parameters
),
),
) # Total: 11_689_512 parameters,
# plus 9_620 states.
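The wrapped model behaves like any other Lux layer. The following is a minimal sketch of a forward pass, assuming the usual Lux.setup workflow and a 224×224 RGB input in width × height × channels × batch layout; the input size and the single-sample batch are illustrative assumptions. Lux.testmode switches the BatchNorm layers to inference mode so that they use their stored running statistics.

```julia
using Lux, Boltz, Random
using Metalhead  # must be loaded before calling Vision.ResNet

rng = Random.Xoshiro(0)

model = Vision.ResNet(18)
ps, st = Lux.setup(rng, model)

# Put stateful layers such as BatchNorm into inference mode.
st = Lux.testmode(st)

# Illustrative input: one random 224x224 RGB image (W x H x C x N).
x = randn(rng, Float32, 224, 224, 3, 1)

y, _ = model(x, ps, st)
println(size(y))
```

The first dimension of the output matches the 1000 units of the final Dense(512 => 1000) layer. The parameters here are randomly initialized, so the values themselves carry no meaning.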
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
JULIA_NUM_THREADS = 1
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.