Merged

Commits (24)
42726a1 - Initial commit for MLP mixer (theabhirath, Jan 29, 2022)
0ad2db9 - Updated directory structure (theabhirath, Jan 29, 2022)
e2a3e4a - Rename test/ConvNets.jl to test/convnets.jl (theabhirath, Jan 29, 2022)
f4e71b9 - Initial commit for MLP mixer (theabhirath, Jan 29, 2022)
4a3d82f - Updated directory structure (theabhirath, Jan 29, 2022)
4116add - Rename test/ConvNets.jl to test/convnets.jl (theabhirath, Jan 29, 2022)
0402ad9 - Merge branch 'mlpmixer' of https://github.com/theabhirath/Metalhead.j… (theabhirath, Jan 29, 2022)
3c61e4b - Update runtests.jl (theabhirath, Jan 29, 2022)
3b58716 - Initial commit for MLP mixer (theabhirath, Jan 29, 2022)
7a72ef4 - Updated directory structure (theabhirath, Jan 29, 2022)
bbf0cdf - Rename test/ConvNets.jl to test/convnets.jl (theabhirath, Jan 29, 2022)
bb49697 - Initial commit for MLP mixer (theabhirath, Jan 29, 2022)
b5382dd - Updated directory structure (theabhirath, Jan 29, 2022)
de7ebf1 - Rename test/ConvNets.jl to test/convnets.jl (theabhirath, Jan 29, 2022)
d2ff28b - Update runtests.jl (theabhirath, Jan 29, 2022)
624c539 - Updated MLPMixer category (theabhirath, Jan 30, 2022)
7adccec - Merge branch 'mlpmixer' of https://github.com/theabhirath/Metalhead.j… (theabhirath, Jan 30, 2022)
d7933b1 - Clean up files (theabhirath, Jan 31, 2022)
ef6030c - Trimmed struct definition for MLPMixer model (theabhirath, Feb 2, 2022)
3b7e421 - Cleaned up MLPMixer implementation (theabhirath, Feb 3, 2022)
04c78c6 - Update Metalhead.jl (theabhirath, Feb 3, 2022)
eb45dee - Cleaned up API for model (theabhirath, Feb 3, 2022)
cabe8aa - Minor formatting tweaks (theabhirath, Feb 4, 2022)
44de174 - Apply suggestions from code review (darsnack, Feb 4, 2022)
5 changes: 3 additions & 2 deletions Project.toml
@@ -8,17 +8,18 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
BSON = "0.3.2"
Flux = "0.12"
Functors = "0.2"
julia = "1.4"
NNlib = "0.7.34"
julia = "1.4"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[publish]
title = "Metalhead.jl"
2 changes: 1 addition & 1 deletion README.md
@@ -41,7 +41,7 @@
| ResNeXt-152 | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNeXt.html) | N |
| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv2.html) | N |
| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv3.html) | N |

| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MLPMixer.html) | N |

## Getting Started

33 changes: 20 additions & 13 deletions src/Metalhead.jl
@@ -5,33 +5,40 @@ using Flux: outputsize, Zygote
using Functors
using BSON
using Artifacts, LazyArtifacts
using Statistics

import Functors

# Models
include("utilities.jl")
include("alexnet.jl")
include("vgg.jl")
include("resnet.jl")
include("googlenet.jl")
include("inception.jl")
include("squeezenet.jl")
include("densenet.jl")
include("resnext.jl")
include("mobilenet.jl")
include("layers.jl")

# CNN models
include("convnets/alexnet.jl")
include("convnets/vgg.jl")
include("convnets/inception.jl")
include("convnets/googlenet.jl")
include("convnets/resnet.jl")
include("convnets/resnext.jl")
include("convnets/densenet.jl")
include("convnets/squeezenet.jl")
include("convnets/mobilenet.jl")

# Other models
include("other/mlpmixer.jl")

export AlexNet,
VGG, VGG11, VGG13, VGG16, VGG19,
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
GoogLeNet, Inception3, SqueezeNet,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
ResNeXt,
MobileNetv2, MobileNetv3
MobileNetv2, MobileNetv3,
MLPMixer

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
:MobileNetv2, :MobileNetv3)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
:MobileNetv2, :MobileNetv3, :MLPMixer)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

end # module
9 files renamed without changes.
126 changes: 126 additions & 0 deletions src/layers.jl
@@ -0,0 +1,126 @@
"""
conv_bn(kernelsize, inplanes, outplanes, activation = relu;
rev = false,
stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init],
initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1)

Create a convolution + batch normalization pair with ReLU activation.

# Arguments
- `kernelsize`: size of the convolution kernel (tuple)
- `inplanes`: number of input feature maps
- `outplanes`: number of output feature maps
- `activation`: the activation function for the final layer
- `rev`: set to `true` to place the batch norm before the convolution
- `stride`: stride of the convolution kernel
- `pad`: padding of the convolution kernel
- `dilation`: dilation of the convolution kernel
- `groups`: groups for the convolution kernel
- `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
- `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#))
- `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#))
"""
function conv_bn(kernelsize, inplanes, outplanes, activation = relu;
rev = false,
initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1,
kwargs...)
layers = []

if rev
activations = (conv = activation, bn = identity)
bnplanes = inplanes
else
activations = (conv = identity, bn = activation)
bnplanes = outplanes
end

push!(layers, Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...))
push!(layers, BatchNorm(Int(bnplanes), activations.bn;
initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum))

return rev ? reverse(layers) : layers
end
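For context, a minimal usage sketch of `conv_bn` (illustrative only, not part of this diff; the kernel size, planes, and input shape are assumptions). Because the function returns a vector of layers, it is usually splatted into a `Chain`:

# Illustrative sketch only: splat the returned layers into a Chain.
layer = Chain(conv_bn((3, 3), 3, 16; pad = 1, bias = false)...)
x = rand(Float32, 32, 32, 3, 1)    # WHCN image batch (assumed shape)
size(layer(x))                     # (32, 32, 16, 1)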

"""
cat_channels(x, y)

Concatenate `x` and `y` along the channel dimension (third dimension).
Equivalent to `cat(x, y; dims=3)`.
Convenient binary reduction operator for use with `Parallel`.
"""
cat_channels(x, y) = cat(x, y; dims = 3)
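A possible pairing of `cat_channels` with `Flux.Parallel` (a sketch; the branch widths and input shape are assumptions):

# Sketch: two 1x1 convolution branches concatenated along the channel dimension.
branches = Parallel(cat_channels, Conv((1, 1), 3 => 4), Conv((1, 1), 3 => 8))
x = rand(Float32, 8, 8, 3, 1)
size(branches(x))                  # (8, 8, 12, 1)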

"""
skip_projection(inplanes, outplanes, downsample = false)

Create a skip projection
([reference](https://arxiv.org/abs/1512.03385v1)).

# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: set to `true` to downsample the input
"""
skip_projection(inplanes, outplanes, downsample = false) = downsample ?
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)...) :
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)...)

# array -> PaddedView(0, array, outplanes) for zero padding arrays
"""
skip_identity(inplanes, outplanes[, downsample])

Create an identity projection
([reference](https://arxiv.org/abs/1512.03385v1)).

# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: this argument is ignored but it is needed for compatibility with [`resnet`](#).
"""
function skip_identity(inplanes, outplanes)
if outplanes > inplanes
return Chain(MaxPool((1, 1), stride = 2),
y -> cat(y, zeros(eltype(y),
size(y, 1),
size(y, 2),
outplanes - inplanes, size(y, 4)); dims = 3))
else
return identity
end
end
skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes)
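A sketch of how these shortcut helpers might appear in a downsampling residual block (illustrative only; the planes, stride, and input shape are assumptions):

# Sketch: main branch plus a 1x1 projection shortcut, summed elementwise.
main     = Chain(conv_bn((3, 3), 16, 32; stride = 2, pad = 1, bias = false)...)
shortcut = skip_projection(16, 32, true)    # strided 1x1 conv + BN
block    = Parallel(+, main, shortcut)
x = rand(Float32, 16, 16, 16, 1)
size(block(x))                     # (8, 8, 32, 1)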

# Patching layer used by many vision transformer-like models
struct Patching{T <: Integer}
patch_height::T
patch_width::T
end
Patching(patch_size) = Patching(patch_size, patch_size)

function (p::Patching)(x)
h, w, c, n = size(x)
hp, wp = h ÷ p.patch_height, w ÷ p.patch_width
xpatch = reshape(x, hp, p.patch_height, wp, p.patch_width, c, n)

return reshape(permutedims(xpatch, (1, 3, 5, 2, 4, 6)), p.patch_height * p.patch_width * c,
hp * wp, n)
end

@functor Patching
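A quick sketch of the shape transformation `Patching` performs (the image size and patch size here are assumptions):

# Sketch: split a 256x256 RGB image into 16x16 patches and flatten each patch.
p = Patching(16)
x = rand(Float32, 256, 256, 3, 1)
size(p(x))                         # (16 * 16 * 3, 256, 1) == (768, 256, 1)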

"""
mlpblock(planes, hidden_planes, dropout = 0., dense = Dense; activation = gelu)

Feedforward block used in many vision transformer-like models.

# Arguments
- `planes`: Number of dimensions in the input and output.
- `hidden_planes`: Number of dimensions in the intermediate layer.
- `dropout`: Dropout rate.
- `dense`: Type of dense layer to use in the feedforward block.
- `activation`: Activation function to use.
"""
function mlpblock(planes, hidden_planes, dropout = 0., dense = Dense; activation = gelu)
Chain(dense(planes, hidden_planes, activation), Dropout(dropout),
dense(hidden_planes, planes, activation), Dropout(dropout))
end
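A standalone sketch of `mlpblock` applied to a (planes, patches, batch) array, as the mixer does internally (the hidden width and dropout rate are assumptions):

# Sketch: Dense-based feedforward block with a 4x hidden expansion.
block = mlpblock(512, 4 * 512, 0.1)
x = rand(Float32, 512, 256, 1)     # (planes, patches, batch)
size(block(x))                     # (512, 256, 1)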
86 changes: 86 additions & 0 deletions src/other/mlpmixer.jl
@@ -0,0 +1,86 @@
# Utility function for creating a residual block with LayerNorm before the residual connection
_residualprenorm(planes, fn) = SkipConnection(Chain(fn, LayerNorm(planes)), +)

# Utility function for 1D convolution
_conv1d(inplanes, outplanes, activation) = Conv((1, ), inplanes => outplanes, activation)

"""
mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000,
token_mix = _conv1d, channel_mix = Dense)

Creates a model with the MLPMixer architecture
([reference](https://arxiv.org/pdf/2105.01601)).

# Arguments
- `imsize`: the size of the input image
- `inchannels`: the number of input channels
- `patch_size`: the size of the patches
- `planes`: the number of channels fed into the main model after the patch embedding
- `depth`: the number of mixer blocks in the main model
- `expansion_factor`: the factor by which the hidden dimension of the MLP blocks is expanded relative to the input dimension
- `dropout`: the dropout rate
- `nclasses`: the number of classes in the output
- `token_mix`: the layer constructor used for the token mixing block
- `channel_mix`: the layer constructor used for the channel mixing block
"""
function mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000,
token_mix = _conv1d, channel_mix = Dense)

im_height, im_width = imsize

@assert (im_height % patch_size == 0) && (im_width % patch_size == 0) "image size must be divisible by patch size"

num_patches = (im_height ÷ patch_size) * (im_width ÷ patch_size)

layers = []
push!(layers, Patching(patch_size))
push!(layers, Dense((patch_size ^ 2) * inchannels, planes))
append!(layers, [Chain(_residualprenorm(planes, mlpblock(num_patches,
expansion_factor * num_patches,
dropout, token_mix)),
_residualprenorm(planes, mlpblock(planes,
expansion_factor * planes, dropout,
channel_mix)),) for _ in 1:depth])

classification_head = Chain(_seconddimmean, Dense(planes, nclasses))

return Chain(layers..., classification_head)
end
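A sketch of calling the (unexported) backbone constructor directly; the small depth used here is an assumption for illustration:

# Sketch: build the functional backbone and run a forward pass.
backbone = mlpmixer((256, 256); patch_size = 16, planes = 512, depth = 2)
x = rand(Float32, 256, 256, 3, 1)
size(backbone(x))                  # (1000, 1)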

struct MLPMixer
layers
end

"""
MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
depth = 12, expansion_factor = 4, dropout = 0., pretrain = false, nclasses = 1000)

Creates a model with the MLPMixer architecture
([reference](https://arxiv.org/pdf/2105.01601)).

# Arguments
- `imsize`: the size of the input image
- `inchannels`: the number of input channels
- `patch_size`: the size of the patches
- `planes`: the number of channels fed into the main model after the patch embedding
- `depth`: the number of mixer blocks in the main model
- `expansion_factor`: the factor by which the hidden dimension of the MLP blocks is expanded relative to the input dimension
- `dropout`: the dropout rate
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
- `nclasses`: the number of classes in the output
"""
function MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
depth = 12, expansion_factor = 4, dropout = 0., pretrain = false, nclasses = 1000)

layers = mlpmixer(imsize; inchannels, patch_size, planes, depth, expansion_factor, dropout,
nclasses)
pretrain && loadpretrain!(layers, string("MLPMixer"))
MLPMixer(layers)
end

(m::MLPMixer)(x) = m.layers(x)

@functor MLPMixer
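A usage sketch for the exported type (pretrained weights are not available for this model per the README table, so `pretrain` is left at its default of `false`):

# Sketch: construct the exported model and classify a single-image batch.
model = MLPMixer((256, 256); patch_size = 16, planes = 512, depth = 12)
x = rand(Float32, 256, 256, 3, 1)
size(model(x))                     # (1000, 1)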
93 changes: 2 additions & 91 deletions src/utilities.jl
@@ -1,94 +1,5 @@
"""
conv_bn(kernelsize, inplanes, outplanes, activation = relu;
rev = false,
stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init],
initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1)

Create a convolution + batch normalization pair with ReLU activation.

# Arguments
- `kernelsize`: size of the convolution kernel (tuple)
- `inplanes`: number of input feature maps
- `outplanes`: number of output feature maps
- `activation`: the activation function for the final layer
- `rev`: set to `true` to place the batch norm before the convolution
- `stride`: stride of the convolution kernel
- `pad`: padding of the convolution kernel
- `dilation`: dilation of the convolution kernel
- `groups`: groups for the convolution kernel
- `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
- `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#))
- `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#))
"""
function conv_bn(kernelsize, inplanes, outplanes, activation = relu;
rev = false,
initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1,
kwargs...)
layers = []

if rev
activations = (conv = activation, bn = identity)
bnplanes = inplanes
else
activations = (conv = identity, bn = activation)
bnplanes = outplanes
end

push!(layers, Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...))
push!(layers, BatchNorm(Int(bnplanes), activations.bn;
initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum))

return rev ? reverse(layers) : layers
end

"""
cat_channels(x, y)

Concatenate `x` and `y` along the channel dimension (third dimension).
Equivalent to `cat(x, y; dims=3)`.
Convenient binary reduction operator for use with `Parallel`.
"""
cat_channels(x, y) = cat(x, y; dims = 3)

"""
skip_projection(inplanes, outplanes, downsample = false)

Create a skip projection
([reference](https://arxiv.org/abs/1512.03385v1)).

# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: set to `true` to downsample the input
"""
skip_projection(inplanes, outplanes, downsample = false) = downsample ?
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)...) :
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)...)

# array -> PaddedView(0, array, outplanes) for zero padding arrays
"""
skip_identity(inplanes, outplanes[, downsample])

Create a identity projection
([reference](https://arxiv.org/abs/1512.03385v1)).

# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: this argument is ignored but it is needed for compatibility with [`resnet`](#).
"""
function skip_identity(inplanes, outplanes)
if outplanes > inplanes
return Chain(MaxPool((1, 1), stride = 2),
y -> cat(y, zeros(eltype(y),
size(y, 1),
size(y, 2),
outplanes - inplanes, size(y, 4)); dims = 3))
else
return identity
end
end
skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes)
# Utility function for classifier head of vision transformer-like models
_seconddimmean(x) = mean(x, dims = 2)[:, 1, :]
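A small sketch of `_seconddimmean`, which averages over the patch (second) dimension for the classifier head:

# Sketch: mean over the second dimension, which is then dropped.
x = rand(Float32, 512, 256, 4)
size(_seconddimmean(x))            # (512, 4)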

"""
weights(model)