"I noticed a potential channel mismatch issue in the Up module (unet/unet_parts.py) when using the bilinear=True option.
For example, in unet_model.py, the first Up module is called as self.up1(x5, x4), where x5 has 1024 channels and x4 (from the skip connection) has 512 channels.
When bilinear=True, the upsampled x5 still has 1024 channels.
The concatenated tensor torch.cat([x4, x1_upsampled], dim=1) will have 512 + 1024 = 1536 channels.
However, the convolution layer self.conv is initialized as DoubleConv(1024, 512, ...) which expects a 1024-channel input.
This will lead to a runtime size mismatch error."
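The mismatch can be reproduced with plain PyTorch, using the channel counts from the example above (the spatial sizes 32 and 64 are arbitrary placeholders, not values taken from the repository):

```python
import torch
import torch.nn as nn

# Hypothetical feature maps with the channel counts described above.
x5 = torch.randn(1, 1024, 32, 32)  # bottleneck features (1024 channels)
x4 = torch.randn(1, 512, 64, 64)   # skip connection (512 channels)

# Bilinear upsampling changes only the spatial size, not the channel count.
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
x5_up = up(x5)
print(x5_up.shape)  # torch.Size([1, 1024, 64, 64])

# Concatenation along the channel dimension yields 512 + 1024 = 1536 channels.
x = torch.cat([x4, x5_up], dim=1)
print(x.shape)      # torch.Size([1, 1536, 64, 64])

# A DoubleConv built as DoubleConv(1024, 512, ...) starts with a Conv2d that
# expects 1024 input channels, so applying it to this tensor fails.
conv = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
try:
    conv(x)
except RuntimeError as e:
    print(e)        # channel-mismatch RuntimeError
```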
"I believe a potential fix could be changing line 51 in unet/unet_parts.py from:
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
to:
self.conv = DoubleConv(in_channels + in_channels // 2, out_channels)
This would correctly handle the concatenated channel dimension of 1536.
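For illustration, here is a minimal, self-contained sketch of the Up module with that change applied. The DoubleConv below is a simplified stand-in for the block defined in unet/unet_parts.py (no batch norm and no padding logic in forward), so this is not the repository's exact code, just a shape check of the proposed fix:

```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Simplified stand-in for the (conv => ReLU) * 2 block in unet_parts.py."""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)

class Up(nn.Module):
    """Upsample x1, concatenate with skip connection x2, then apply DoubleConv."""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # Proposed fix: the conv must accept the skip-connection channels as well
            # (in_channels from the upsampled path + in_channels // 2 from the skip).
            self.conv = DoubleConv(in_channels + in_channels // 2, out_channels)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

# With the example shapes from this report:
x5 = torch.randn(1, 1024, 32, 32)  # bottleneck
x4 = torch.randn(1, 512, 64, 64)   # skip connection
up1 = Up(1024, 512, bilinear=True)
print(up1(x5, x4).shape)           # torch.Size([1, 512, 64, 64])
```

With in_channels = 1024, the proposed DoubleConv(in_channels + in_channels // 2, out_channels) expects 1024 + 512 = 1536 input channels, which matches the concatenated tensor.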