mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
feat: made model.py extensible
Former-commit-id: 0261a3f6caf571dc19e6ca97d2ecada1e72c7f04
This commit is contained in:
parent
bc2aeacfa3
commit
a42190ec61
|
@ -20,20 +20,24 @@ class UNet(nn.Module):
|
||||||
self.ups = nn.ModuleList()
|
self.ups = nn.ModuleList()
|
||||||
for i in range(len(features) - 1):
|
for i in range(len(features) - 1):
|
||||||
self.ups.append(
|
self.ups.append(
|
||||||
Up(*features[::-1][i : i + 2]),
|
Up(*features[-1 - i : -1 - i + 3 : -1]),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.outc = OutConv(features[0], n_classes)
|
self.outc = OutConv(features[0], n_classes)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x1 = self.inc(x)
|
|
||||||
x2 = self.down1(x1)
|
skips = []
|
||||||
x3 = self.down2(x2)
|
|
||||||
x4 = self.down3(x3)
|
x = self.inc(x)
|
||||||
x5 = self.down4(x4)
|
|
||||||
x = self.up1(x5, x4)
|
for down in self.downs:
|
||||||
x = self.up2(x, x3)
|
skips.append(x)
|
||||||
x = self.up3(x, x2)
|
x = down(x)
|
||||||
x = self.up4(x, x1)
|
|
||||||
logits = self.outc(x)
|
for up, skip in zip(self.ups, reversed(skips)):
|
||||||
return logits
|
x = up(x, skip)
|
||||||
|
|
||||||
|
x = self.outc(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
Loading…
Reference in a new issue