feat: made model.py extensible

Former-commit-id: 0261a3f6caf571dc19e6ca97d2ecada1e72c7f04
This commit is contained in:
Your Name 2022-06-27 16:31:07 +02:00
parent bc2aeacfa3
commit a42190ec61

View file

@ -20,20 +20,24 @@ class UNet(nn.Module):
self.ups = nn.ModuleList() self.ups = nn.ModuleList()
for i in range(len(features) - 1): for i in range(len(features) - 1):
self.ups.append( self.ups.append(
Up(*features[::-1][i : i + 2]), Up(*features[-1 - i : -1 - i + 3 : -1]),
) )
self.outc = OutConv(features[0], n_classes) self.outc = OutConv(features[0], n_classes)
def forward(self, x): def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1) skips = []
x3 = self.down2(x2)
x4 = self.down3(x3) x = self.inc(x)
x5 = self.down4(x4)
x = self.up1(x5, x4) for down in self.downs:
x = self.up2(x, x3) skips.append(x)
x = self.up3(x, x2) x = down(x)
x = self.up4(x, x1)
logits = self.outc(x) for up, skip in zip(self.ups, reversed(skips)):
return logits x = up(x, skip)
x = self.outc(x)
return x