diff --git a/unet_demo.py b/unet_demo.py index 5ae1bb0..6d34508 100644 --- a/unet_demo.py +++ b/unet_demo.py @@ -81,6 +81,11 @@ def __init__(self, in_channels, out_channels, n_class, kernel_size, self.down3 = DownConv(4 * out_channels, 8 * out_channels, kernel_size, padding, stride) + self.down4 = DownConv(8 * out_channels, 16 * out_channels, kernel_size, + padding, stride) + + self.up4 = UpConv(16 * out_channels, 8 * out_channels, 8 * out_channels, + kernel_size, padding, stride) self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels, kernel_size, padding, stride) @@ -99,8 +104,10 @@ def forward(self, x): x1 = self.down1(x) x2 = self.down2(x1) x3 = self.down3(x2) + x4 = self.down4(x3) # Decoder - x_up = self.up3(x3, x2) + x_up = self.up4(x4, x3) + x_up = self.up3(x_up, x2) x_up = self.up2(x_up, x1) x_up = self.up1(x_up, x) x_out = F.log_softmax(self.out(x_up), 1)