From a42190ec6160f11e25365e6ef02ab39862b5007e Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 27 Jun 2022 16:31:07 +0200 Subject: [PATCH] feat: made model.py extensible Former-commit-id: 0261a3f6caf571dc19e6ca97d2ecada1e72c7f04 --- src/unet/model.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/unet/model.py b/src/unet/model.py index b1bb22c..73872c5 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -20,20 +20,24 @@ class UNet(nn.Module): self.ups = nn.ModuleList() for i in range(len(features) - 1): self.ups.append( - Up(*features[::-1][i : i + 2]), + Up(*features[-1 - i : -1 - i + 3 : -1]), ) self.outc = OutConv(features[0], n_classes) def forward(self, x): - x1 = self.inc(x) - x2 = self.down1(x1) - x3 = self.down2(x2) - x4 = self.down3(x3) - x5 = self.down4(x4) - x = self.up1(x5, x4) - x = self.up2(x, x3) - x = self.up3(x, x2) - x = self.up4(x, x1) - logits = self.outc(x) - return logits + + skips = [] + + x = self.inc(x) + + for down in self.downs: + skips.append(x) + x = down(x) + + for up, skip in zip(self.ups, reversed(skips)): + x = up(x, skip) + + x = self.outc(x) + + return x