feat: reduce the number of parameters in the net

Former-commit-id: 862569b6d284ec8235586b161d8c7055c006f5d8 [formerly f2e672d780df12a398e851f375a238c2d394a3cd]
Former-commit-id: 740b1129a627c488537bb0d0dc7ff73b66fde813
Author: Laurent Fainsin
Date:   2022-06-30 10:47:53 +02:00
parent 9fe76d8c61
commit dac6237906
5 changed files with 17 additions and 12 deletions

.gitignore
@@ -1,6 +1,10 @@
 .venv/
 .mypy_cache/
 __pycache__/
 wandb/
+images/
 *.pth
+*.png
+*.jpg

.vscode/launch.json
@@ -12,9 +12,9 @@
             "console": "integratedTerminal",
             "args": [
                 "--input",
-                "SM.png",
+                "images/SM.png",
                 "--output",
-                "test.png",
+                "output.png",
             ],
             "justMyCode": true
         }

(file name not shown)
@@ -1 +0,0 @@
-c6d08aa612451072cfe32a3ee086d08342ed9dd9

(training script, file name not shown)
@@ -18,7 +18,7 @@ from utils.paste import RandomPaste
 
 CHECKPOINT_DIR = Path("./checkpoints/")
 DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017")
-DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017/")
+DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/")
 DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
 DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
@@ -41,7 +41,7 @@ def get_args():
         dest="batch_size",
         metavar="B",
         type=int,
-        default=16,
+        default=70,
         help="Batch size",
     )
     parser.add_argument(
@@ -92,11 +92,14 @@ def main():
     torch.backends.cudnn.benchmark = True
 
     # 0. Create network
-    net = UNet(n_channels=3, n_classes=args.classes)
+    features = [16, 32, 64, 128]
+    net = UNet(n_channels=args.n_channels, n_classes=args.classes, features=features)
 
     logging.info(
         f"""Network:
             input channels: {net.n_channels}
             output channels: {net.n_classes}
+            nb parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad)}
+            features: {features}
         """
     )
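The commit's point is that the network width is now driven by a features list, with the trainable-parameter count logged alongside it. As a rough illustration, the sketch below uses a hypothetical minimal UNet (the repo's real UNet class is not shown here; the wider baseline [64, 128, 256, 512] and the names DoubleConv and n_params are assumptions): convolution parameters scale with in_channels * out_channels, so making every stage four times narrower cuts them by roughly sixteen times, which the same sum(p.numel() ...) expression from the log makes visible.

import torch
from torch import nn


class DoubleConv(nn.Sequential):
    """Two 3x3 conv + BN + ReLU blocks, the usual UNet building brick."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )


class UNet(nn.Module):
    """Hypothetical UNet whose width is driven entirely by `features`."""

    def __init__(self, n_channels: int, n_classes: int, features: list[int]):
        super().__init__()
        self.n_channels, self.n_classes = n_channels, n_classes
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(2)

        # encoder: n_channels -> features[0] -> ... -> features[-1]
        in_ch = n_channels
        for f in features:
            self.downs.append(DoubleConv(in_ch, f))
            in_ch = f

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # decoder: mirror of the encoder, with skip connections
        for f in reversed(features):
            self.ups.append(nn.ConvTranspose2d(f * 2, f, 2, stride=2))
            self.ups.append(DoubleConv(f * 2, f))

        self.head = nn.Conv2d(features[0], n_classes, 1)

    def forward(self, x):
        skips = []
        for down in self.downs:
            x = down(x)
            skips.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        for up, conv, skip in zip(self.ups[::2], self.ups[1::2], reversed(skips)):
            x = up(x)
            x = conv(torch.cat([skip, x], dim=1))
        return self.head(x)


def n_params(net: nn.Module) -> int:
    # same expression the commit adds to the logging block
    return sum(p.numel() for p in net.parameters() if p.requires_grad)


small = UNet(n_channels=3, n_classes=1, features=[16, 32, 64, 128])
big = UNet(n_channels=3, n_classes=1, features=[64, 128, 256, 512])
print(n_params(small), n_params(big))  # the small net is roughly 16x lighter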
@@ -138,7 +141,7 @@ def main():
     ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
 
     # 3. Create data loaders
-    loader_args = dict(batch_size=args.batch_size, num_workers=6, pin_memory=True)
+    loader_args = dict(batch_size=args.batch_size, num_workers=8, pin_memory=True)
     train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
     val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
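For reference, a small stand-alone sketch of how those loader_args map onto torch.utils.data.DataLoader; the TensorDataset stand-in and the tensor shapes are made up, only batch_size, num_workers, pin_memory and drop_last mirror the diff.

import torch
from torch.utils.data import DataLoader, TensorDataset

# stand-in datasets; the commit uses SphereDataset instead
ds_train = TensorDataset(torch.randn(256, 3, 64, 64), torch.randint(0, 2, (256, 1, 64, 64)))
ds_valid = TensorDataset(torch.randn(64, 3, 64, 64), torch.randint(0, 2, (64, 1, 64, 64)))

# shared settings, mirroring the diff: pin_memory speeds up host-to-GPU copies,
# num_workers controls how many subprocesses prefetch batches
loader_args = dict(batch_size=70, num_workers=8, pin_memory=True)

train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
# drop_last avoids a smaller final batch skewing the validation score
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)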
@@ -159,9 +162,9 @@ def main():
         ),
     )
     wandb.watch(net, log_freq=100)
-    # artifact = wandb.Artifact("model", type="model")
-    # artifact.add_file("model.pth")
-    # run.log_artifact(artifact)
+    artifact_model = wandb.Artifact("model", type="model")
+    artifact_model.add_file("model.pth")
+    run.log_artifact(artifact_model)
 
     logging.info(
         f"""Starting training:
@@ -228,8 +231,7 @@ def main():
         print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}")
 
         # save weights when epoch end
-        # torch.save(net.state_dict(), "model.pth")
-        # run.log_artifact(artifact)
+        torch.save(net.state_dict(), "model.pth")
         logging.info(f"model saved!")
 
     run.finish()
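Taken together, the two re-enabled blocks form a save-and-version pattern: write model.pth, then publish it as a wandb model artifact. Below is a minimal sketch of that pattern, assuming a throwaway nn.Conv2d in place of the real net and an offline wandb run; unlike the diff, which logs the artifact once before training, this variant re-logs it after each save so every checkpoint becomes a new artifact version.

import torch
import wandb
from torch import nn

# hypothetical minimal setup; the real run and net come from the training script
run = wandb.init(project="example-project", mode="offline")
net = nn.Conv2d(3, 1, 3)

for epoch in range(3):
    # ... training and validation would happen here ...

    # save weights when the epoch ends, as in the diff
    torch.save(net.state_dict(), "model.pth")

    # version the checkpoint as a wandb artifact;
    # add_file snapshots model.pth as it exists right now
    artifact_model = wandb.Artifact("model", type="model")
    artifact_model.add_file("model.pth")
    run.log_artifact(artifact_model)

run.finish()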

test.png (BIN, 1.5 KiB, deleted; binary file not shown)