Commit cabe8aa

Minor formatting tweaks
1 parent eb45dee commit cabe8aa

File tree: 3 files changed, +10 -9 lines


src/Metalhead.jl

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ export AlexNet,
 
 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
-          :MobileNetv2, :MobileNetv3, :MLPMixer)
+          :MobileNetv2, :MobileNetv3, :MLPMixer)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
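
For reference, the @eval loop above defines one text/plain show method per listed model type, so large models are printed through Metalhead's _maybe_big_show helper (which uses Flux._big_show). For MLPMixer, the generated method is equivalent to this illustrative expansion (not part of the diff):

    # what the loop body evaluates to when T == :MLPMixer
    Base.show(io::IO, ::MIME"text/plain", model::MLPMixer) = _maybe_big_show(io, model)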

src/other/mlpmixer.jl

Lines changed: 7 additions & 6 deletions

@@ -47,7 +47,7 @@ function mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 1
 
     classification_head = Chain(_seconddimmean, Dense(planes, nclasses))
 
-    return Chain(layers..., classification_head)
+    return Chain(Chain(layers...), classification_head)
 end
 
 struct MLPMixer
@@ -56,7 +56,7 @@ end
 
 """
     MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
-             depth = 12, expansion_factor = 4, dropout = 0., pretrain = false, nclasses = 1000)
+             depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000)
 
 Creates a model with the MLPMixer architecture.
 ([reference](https://arxiv.org/pdf/2105.01601)).
@@ -69,18 +69,19 @@ Creates a model with the MLPMixer architecture.
 - depth: the number of blocks in the main model
 - expansion_factor: the number of channels in each block
 - dropout: the dropout rate
-- pretrain: whether to load the pre-trained weights for ImageNet
 - nclasses: the number of classes in the output
 """
 function MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
-                  depth = 12, expansion_factor = 4, dropout = 0., pretrain = false, nclasses = 1000)
+                  depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000)
 
     layers = mlpmixer(imsize; inchannels, patch_size, planes, depth, expansion_factor, dropout,
                       nclasses)
-    pretrain && loadpretrain!(layers, string("MLPMixer"))
     MLPMixer(layers)
 end
 
+@functor MLPMixer
+
 (m::MLPMixer)(x) = m.layers(x)
 
-@functor MLPMixer
+backbone(m::MLPMixer) = m.layers[1]
+classifier(m::MLPMixer) = m.layers[2:end]
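
The nested Chain returned by mlpmixer, together with the new backbone/classifier accessors, lets the feature extractor and the classification head be pulled out separately. A minimal usage sketch against this commit (backbone and classifier may not be exported, so they are qualified with the module name here):

    using Metalhead

    m = MLPMixer()
    x = rand(Float32, 256, 256, 3, 1)

    feats = Metalhead.backbone(m)(x)    # inner Chain: patch embedding + mixer blocks
    y = Metalhead.classifier(m)(feats)  # head: _seconddimmean followed by the Dense layer
    size(y)                             # (1000, 1), same as size(m(x))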

test/other.jl

Lines changed: 2 additions & 2 deletions

@@ -2,6 +2,6 @@ using Metalhead, Test
 using Flux
 
 @testset "MLPMixer" begin
-    @test size(MLPMixer()(rand(Float32, 256, 256, 3, 67))) == (1000, 67)
-    @test_skip gradtest(MLPMixer(), rand(Float32, 256, 256, 3, 67))
+    @test size(MLPMixer()(rand(Float32, 256, 256, 3, 2))) == (1000, 2)
+    @test_skip gradtest(MLPMixer(), rand(Float32, 256, 256, 3, 2))
 end
