Implementation of MLPMixer #103
Merged
Commits (24):
42726a1 Initial commit for MLP mixer (theabhirath)
0ad2db9 Updated directory structure (theabhirath)
e2a3e4a Rename test/ConvNets.jl to test/convnets.jl (theabhirath)
f4e71b9 Initial commit for MLP mixer (theabhirath)
4a3d82f Updated directory structure (theabhirath)
4116add Rename test/ConvNets.jl to test/convnets.jl (theabhirath)
0402ad9 Merge branch 'mlpmixer' of https://github.com/theabhirath/Metalhead.j… (theabhirath)
3c61e4b Update runtests.jl (theabhirath)
3b58716 Initial commit for MLP mixer (theabhirath)
7a72ef4 Updated directory structure (theabhirath)
bbf0cdf Rename test/ConvNets.jl to test/convnets.jl (theabhirath)
bb49697 Initial commit for MLP mixer (theabhirath)
b5382dd Updated directory structure (theabhirath)
de7ebf1 Rename test/ConvNets.jl to test/convnets.jl (theabhirath)
d2ff28b Update runtests.jl (theabhirath)
624c539 Updated MLPMixer category (theabhirath)
7adccec Merge branch 'mlpmixer' of https://github.com/theabhirath/Metalhead.j… (theabhirath)
d7933b1 Clean up files (theabhirath)
ef6030c Trimmed struct definition for MLPMixer model (theabhirath)
3b7e421 Cleaned up MLPMixer implementation (theabhirath)
04c78c6 Update Metalhead.jl (theabhirath)
eb45dee Cleaned up API for model (theabhirath)
cabe8aa Minor formatting tweaks (theabhirath)
44de174 Apply suggestions from code review (darsnack)
9 files renamed without changes.
New file (126 lines added):
| """ | ||
| conv_bn(kernelsize, inplanes, outplanes, activation = relu; | ||
| rev = false, | ||
| stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init], | ||
| initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1) | ||
|
|
||
| Create a convolution + batch normalization pair with ReLU activation. | ||
|
|
||
| # Arguments | ||
| - `kernelsize`: size of the convolution kernel (tuple) | ||
| - `inplanes`: number of input feature maps | ||
| - `outplanes`: number of output feature maps | ||
| - `activation`: the activation function for the final layer | ||
| - `rev`: set to `true` to place the batch norm before the convolution | ||
| - `stride`: stride of the convolution kernel | ||
| - `pad`: padding of the convolution kernel | ||
| - `dilation`: dilation of the convolution kernel | ||
| - `groups`: groups for the convolution kernel | ||
| - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) | ||
| - `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#)) | ||
| - `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#)) | ||
| """ | ||
| function conv_bn(kernelsize, inplanes, outplanes, activation = relu; | ||
| rev = false, | ||
| initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1, | ||
| kwargs...) | ||
| layers = [] | ||
|
|
||
| if rev | ||
| activations = (conv = activation, bn = identity) | ||
| bnplanes = inplanes | ||
| else | ||
| activations = (conv = identity, bn = activation) | ||
| bnplanes = outplanes | ||
| end | ||
|
|
||
| push!(layers, Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...)) | ||
| push!(layers, BatchNorm(Int(bnplanes), activations.bn; | ||
| initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum)) | ||
|
|
||
| return rev ? reverse(layers) : layers | ||
| end | ||
|
|
||
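For instance (a minimal usage sketch, not part of the diff; it assumes `Flux` is loaded and `conv_bn` is in scope), the returned vector of layers can be splatted into a `Chain`:

```julia
using Flux

# 3×3 convolution from 3 to 16 channels; the batch norm carries the default ReLU
layers = conv_bn((3, 3), 3, 16; pad = 1)
model = Chain(layers...)

x = rand(Float32, 32, 32, 3, 1)   # WHCN image batch
size(model(x))                    # (32, 32, 16, 1)
```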
| """ | ||
| cat_channels(x, y) | ||
|
|
||
| Concatenate `x` and `y` along the channel dimension (third dimension). | ||
| Equivalent to `cat(x, y; dims=3)`. | ||
| Convenient binary reduction operator for use with `Parallel`. | ||
| """ | ||
| cat_channels(x, y) = cat(x, y; dims = 3) | ||
|
|
||
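As a sketch of how this pairs with `Flux.Parallel` (the branch layers below are illustrative, not taken from the PR):

```julia
using Flux

# two branches whose outputs are fused along the channel dimension
branches = Parallel(cat_channels,
                    Conv((1, 1), 8 => 4),
                    Conv((3, 3), 8 => 4; pad = 1))

x = rand(Float32, 16, 16, 8, 1)
size(branches(x))   # (16, 16, 8, 1): 4 + 4 channels after concatenation
```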
| """ | ||
| skip_projection(inplanes, outplanes, downsample = false) | ||
|
|
||
| Create a skip projection | ||
| ([reference](https://arxiv.org/abs/1512.03385v1)). | ||
|
|
||
| # Arguments: | ||
| - `inplanes`: the number of input feature maps | ||
| - `outplanes`: the number of output feature maps | ||
| - `downsample`: set to `true` to downsample the input | ||
| """ | ||
| skip_projection(inplanes, outplanes, downsample = false) = downsample ? | ||
| Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)...) : | ||
| Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)...) | ||
|
|
||
| # array -> PaddedView(0, array, outplanes) for zero padding arrays | ||
| """ | ||
| skip_identity(inplanes, outplanes[, downsample]) | ||
|
|
||
| Create a identity projection | ||
| ([reference](https://arxiv.org/abs/1512.03385v1)). | ||
|
|
||
| # Arguments: | ||
| - `inplanes`: the number of input feature maps | ||
| - `outplanes`: the number of output feature maps | ||
| - `downsample`: this argument is ignored but it is needed for compatibility with [`resnet`](#). | ||
| """ | ||
| function skip_identity(inplanes, outplanes) | ||
| if outplanes > inplanes | ||
| return Chain(MaxPool((1, 1), stride = 2), | ||
| y -> cat(y, zeros(eltype(y), | ||
| size(y, 1), | ||
| size(y, 2), | ||
| outplanes - inplanes, size(y, 4)); dims = 3)) | ||
| else | ||
| return identity | ||
| end | ||
| end | ||
| skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes) | ||
|
|
||
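A small shape-only sketch of the identity shortcut (assuming the definitions above; the numbers are illustrative):

```julia
using Flux

shortcut = skip_identity(16, 32)     # widening shortcut: pool with stride 2, zero-pad channels
x = rand(Float32, 8, 8, 16, 1)
size(shortcut(x))                    # (4, 4, 32, 1)

skip_identity(32, 32) === identity   # same width: plain identity shortcut
```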
```julia
# Patching layer used by many vision transformer-like models
struct Patching{T <: Integer}
    patch_height::T
    patch_width::T
end
Patching(patch_size) = Patching(patch_size, patch_size)

function (p::Patching)(x)
    h, w, c, n = size(x)
    hp, wp = h ÷ p.patch_height, w ÷ p.patch_width
    xpatch = reshape(x, hp, p.patch_height, wp, p.patch_width, c, n)

    # output shape: (patch_height * patch_width * c, number of patches, batch)
    return reshape(permutedims(xpatch, (1, 3, 5, 2, 4, 6)),
                   p.patch_height * p.patch_width * c, hp * wp, n)
end

@functor Patching
```
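A shape-only sketch of the patching transform (assuming the layer above and Flux's WHCN layout):

```julia
p = Patching(16)                    # 16×16 patches
x = rand(Float32, 256, 256, 3, 2)   # two 256×256 RGB images

size(p(x))   # (768, 256, 2): 16*16*3 values per patch, (256 ÷ 16)^2 = 256 patches per image
```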
| """ | ||
| mlpblock(planes, expansion_factor = 4, dropout = 0., dense = Dense) | ||
|
|
||
| Feedforward block used in many vision transformer-like models. | ||
|
|
||
| # Arguments | ||
| `planes`: Number of dimensions in the input and output. | ||
| `hidden_planes`: Number of dimensions in the intermediate layer. | ||
| `dropout`: Dropout rate. | ||
| `dense`: Type of dense layer to use in the feedforward block. | ||
| `activation`: Activation function to use. | ||
| """ | ||
| function mlpblock(planes, hidden_planes, dropout = 0., dense = Dense; activation = gelu) | ||
| Chain(dense(planes, hidden_planes, activation), Dropout(dropout), | ||
| dense(hidden_planes, planes, activation), Dropout(dropout)) | ||
| end |
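For example (a minimal sketch, not part of the diff), with the default `Dense` layers the block mixes the feature (first) dimension:

```julia
using Flux

block = mlpblock(512, 2048, 0.1)   # 512 → 2048 → 512 with GELU and dropout
x = rand(Float32, 512, 196, 4)     # (planes, patches, batch)
size(block(x))                     # (512, 196, 4)
```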
New file (86 lines added):
```julia
# Utility function for creating a residual block with LayerNorm before the residual connection
_residualprenorm(planes, fn) = SkipConnection(Chain(fn, LayerNorm(planes)), +)

# Utility function for 1D convolution
_conv1d(inplanes, outplanes, activation) = Conv((1, ), inplanes => outplanes, activation)
```
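As a sketch (assuming Flux's `SkipConnection` semantics), the first helper computes `x .+ LayerNorm(planes)(fn(x))`; the toy layer below is illustrative:

```julia
using Flux

block = _residualprenorm(64, Dense(64, 64, gelu))
x = rand(Float32, 64, 10)
size(block(x))   # (64, 10): fn, then LayerNorm, then added back onto the input
```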
| """ | ||
| mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512, | ||
| depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000, token_mix = | ||
| _conv1d, channel_mix = Dense)) | ||
| Creates a model with the MLPMixer architecture. | ||
| ([reference](https://arxiv.org/pdf/2105.01601)). | ||
| # Arguments | ||
| - imsize: the size of the input image | ||
| - inchannels: the number of input channels | ||
| - patch_size: the size of the patches | ||
| - planes: the number of channels fed into the main model | ||
| - depth: the number of blocks in the main model | ||
| - expansion_factor: the number of channels in each block | ||
| - dropout: the dropout rate | ||
| - nclasses: the number of classes in the output | ||
| - token_mix: the function to use for the token mixing layer | ||
| - channel_mix: the function to use for the channel mixing layer | ||
| """ | ||
| function mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512, | ||
| depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000, token_mix = | ||
| _conv1d, channel_mix = Dense) | ||
|
|
||
| im_height, im_width = imsize | ||
|
|
||
| @assert (im_height % patch_size) == 0 && (im_width % patch_size == 0) | ||
| "image size must be divisible by patch size" | ||
|
|
||
| num_patches = (im_height ÷ patch_size) * (im_width ÷ patch_size) | ||
|
|
||
| layers = [] | ||
| push!(layers, Patching(patch_size)) | ||
| push!(layers, Dense((patch_size ^ 2) * inchannels, planes)) | ||
| append!(layers, [Chain(_residualprenorm(planes, mlpblock(num_patches, | ||
| expansion_factor * num_patches, | ||
| dropout, token_mix)), | ||
| _residualprenorm(planes, mlpblock(planes, | ||
| expansion_factor * planes, dropout, | ||
| channel_mix)),) for _ in 1:depth]) | ||
|
|
||
| classification_head = Chain(_seconddimmean, Dense(planes, nclasses)) | ||
|
|
||
| return Chain(layers..., classification_head) | ||
theabhirath marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| end | ||
|
|
||
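A shape-only usage sketch (sizes are arbitrary; it assumes the package's `_seconddimmean` pooling helper is in scope):

```julia
m = mlpmixer((64, 64); patch_size = 8, planes = 128, depth = 2, nclasses = 10)

x = rand(Float32, 64, 64, 3, 2)   # two 64×64 RGB images
size(m(x))                        # (10, 2)
```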
```julia
struct MLPMixer
    layers
end

"""
    MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
             depth = 12, expansion_factor = 4, dropout = 0., pretrain = false, nclasses = 1000)

Creates a model with the MLPMixer architecture.
([reference](https://arxiv.org/pdf/2105.01601)).

# Arguments
- `imsize`: the size of the input image
- `inchannels`: the number of input channels
- `patch_size`: the size of the patches
- `planes`: the number of channels fed into the main model
- `depth`: the number of blocks in the main model
- `expansion_factor`: ratio of the hidden dimension to the input dimension in the mixer blocks
- `dropout`: the dropout rate
- `pretrain`: whether to load the pre-trained weights for ImageNet
- `nclasses`: the number of classes in the output
"""
function MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
                  depth = 12, expansion_factor = 4, dropout = 0., pretrain = false, nclasses = 1000)
    layers = mlpmixer(imsize; inchannels, patch_size, planes, depth, expansion_factor, dropout,
                      nclasses)
    pretrain && loadpretrain!(layers, string("MLPMixer"))
    MLPMixer(layers)
end

(m::MLPMixer)(x) = m.layers(x)

@functor MLPMixer
```
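For example (a minimal sketch; the `pretrain = true` path relies on Metalhead's `loadpretrain!` and ImageNet weights, which are not exercised here):

```julia
using Flux

model = MLPMixer((256, 256); patch_size = 16, planes = 512, depth = 12)

x = rand(Float32, 256, 256, 3, 1)   # a single 256×256 RGB image
size(model(x))                      # (1000, 1) with the default nclasses = 1000
```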