@@ -47,7 +47,7 @@ function mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 1
4747
4848 classification_head = Chain (_seconddimmean, Dense (planes, nclasses))
4949
50- return Chain (layers... , classification_head)
50+ return Chain (Chain ( layers... ) , classification_head)
5151end
5252
5353struct MLPMixer
5656
5757"""
5858 MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
59- depth = 12, expansion_factor = 4, dropout = 0., pretrain = false, nclasses = 1000)
59+ depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000)
6060
6161Creates a model with the MLPMixer architecture.
6262([reference](https://arxiv.org/pdf/2105.01601)).
@@ -69,18 +69,19 @@ Creates a model with the MLPMixer architecture.
6969- depth: the number of blocks in the main model
7070- expansion_factor: the factor by which the hidden width of each MLP block is expanded
7171- dropout: the dropout rate
72- - pretrain: whether to load the pre-trained weights for ImageNet
7372- nclasses: the number of classes in the output
7473"""
function MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16,
                  planes = 512, depth = 12, expansion_factor = 4, dropout = 0.,
                  nclasses = 1000)
    # Assemble the backbone and classification head with the functional
    # builder, then wrap the resulting layers in the MLPMixer struct.
    model = mlpmixer(imsize; inchannels, patch_size, planes, depth,
                     expansion_factor, dropout, nclasses)
    return MLPMixer(model)
end
8381
# Register MLPMixer with Functors so its fields are traversed (e.g. for
# parameter collection and device movement) — presumably Flux's `@functor`;
# confirm against the file's imports.
@functor MLPMixer

# Calling a model instance forwards the input through the wrapped layers.
function (m::MLPMixer)(x)
    return m.layers(x)
end
8585
"""
    backbone(m::MLPMixer)

Return the feature-extraction part of the model — the first stage of the
wrapped `Chain`.
"""
function backbone(m::MLPMixer)
    return m.layers[1]
end

"""
    classifier(m::MLPMixer)

Return the classification head — everything after the first stage of the
wrapped `Chain`.
"""
function classifier(m::MLPMixer)
    return m.layers[2:end]
end
0 commit comments