From dac6237906174cb206c64b9dbedd7c0a65441151 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Thu, 30 Jun 2022 10:47:53 +0200 Subject: [PATCH] feat: reduce the number of parameters in the net Former-commit-id: 862569b6d284ec8235586b161d8c7055c006f5d8 [formerly f2e672d780df12a398e851f375a238c2d394a3cd] Former-commit-id: 740b1129a627c488537bb0d0dc7ff73b66fde813 --- .gitignore | 4 ++++ .vscode/launch.json | 4 ++-- SM.png.REMOVED.git-id | 1 - src/train.py | 20 +++++++++++--------- test.png | Bin 1575 -> 0 bytes 5 files changed, 17 insertions(+), 12 deletions(-) delete mode 100644 SM.png.REMOVED.git-id delete mode 100644 test.png diff --git a/.gitignore b/.gitignore index 9d60c97..da4c1bc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,10 @@ .venv/ .mypy_cache/ __pycache__/ + wandb/ +images/ *.pth +*.png +*.jpg diff --git a/.vscode/launch.json b/.vscode/launch.json index f7a9165..bccf130 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,9 +12,9 @@ "console": "integratedTerminal", "args": [ "--input", - "SM.png", + "images/SM.png", "--output", - "test.png", + "output.png", ], "justMyCode": true } diff --git a/SM.png.REMOVED.git-id b/SM.png.REMOVED.git-id deleted file mode 100644 index 21a947e..0000000 --- a/SM.png.REMOVED.git-id +++ /dev/null @@ -1 +0,0 @@ -c6d08aa612451072cfe32a3ee086d08342ed9dd9 \ No newline at end of file diff --git a/src/train.py b/src/train.py index 4f56bde..49f6dc5 100644 --- a/src/train.py +++ b/src/train.py @@ -18,7 +18,7 @@ from utils.paste import RandomPaste CHECKPOINT_DIR = Path("./checkpoints/") DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017") -DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017/") +DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/") DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/") DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/") @@ -41,7 +41,7 @@ def get_args(): dest="batch_size", metavar="B", type=int, - default=16, + default=70, help="Batch size", ) parser.add_argument( @@ -92,11 +92,14 @@ def main(): torch.backends.cudnn.benchmark = True # 0. Create network - net = UNet(n_channels=3, n_classes=args.classes) + features = [16, 32, 64, 128] + net = UNet(n_channels=args.n_channels, n_classes=args.classes, features=features) logging.info( f"""Network: input channels: {net.n_channels} output channels: {net.n_classes} + nb parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad)} + features: {features} """ ) @@ -138,7 +141,7 @@ def main(): ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid) # 3. Create data loaders - loader_args = dict(batch_size=args.batch_size, num_workers=6, pin_memory=True) + loader_args = dict(batch_size=args.batch_size, num_workers=8, pin_memory=True) train_loader = DataLoader(ds_train, shuffle=True, **loader_args) val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args) @@ -159,9 +162,9 @@ def main(): ), ) wandb.watch(net, log_freq=100) - # artifact = wandb.Artifact("model", type="model") - # artifact.add_file("model.pth") - # run.log_artifact(artifact) + artifact_model = wandb.Artifact("model", type="model") + artifact_model.add_file("model.pth") + run.log_artifact(artifact_model) logging.info( f"""Starting training: @@ -228,8 +231,7 @@ def main(): print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}") # save weights when epoch end - # torch.save(net.state_dict(), "model.pth") - # run.log_artifact(artifact) + torch.save(net.state_dict(), "model.pth") logging.info(f"model saved!") run.finish() diff --git a/test.png b/test.png deleted file mode 100644 index a0cad88ce1832fdd02c3e4968e3e838520b3af20..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1575 zcmZ9Mdr*=I6ve5ru{9H$+H5mvEz3u1&&?Vi5oHmUFU%eqkk*c&NoBxiAzC>3Sgl+% zwG{JhrC^$-Kwws5rHPdW3Vdp1k`Ny%3Lzp}?e5g<{p&aPk281fIdjs30?_lVmRiAJ zu=$uhzE~J+c0CMcftdAW0kgIqV!&W?AdD|+e|(`DM3UsYm!yp4^{fuF*!^JS$P(L( zWe@#r?NIxMoIH2<*`an+J)aeKahW4}Z~@t3vFnjJ_S-Mn)=;t5&p)))&F|AV6H65b z<4n9j0I^~YJ++y6r^q7HmGsQCy z*NN?PyPlfugsVMXbgF15T@Fqm;mCeBIE&90?NmXe@t4i~d16v_>SexcIKo}~K5`nQ z3fzvaXu=9iKxoKO7{*M zhMf^f0Pwk63hrGR6{{Xr+?yA*(db%{({Dbug_>KbUmys;`W@7m>)@xQDaXGV_TzK@ z;pWh&*dV~Z+b6xZX!{M_h0%=ON)$8ETA5g2X%xBytBk+HgN>qIpY~+BF!<`cqeOYn zt3)%-lClK_o(0xdw>_<+2}5d0oo{c;;%Sn+_pO8KfJZu4`xNWh^amj}QtMdc zZ`p4+V3Ml~!9c}IjR~l`_hG%-Mt$1jroT9&aVEYqVSiVY%i<2hO!Oy_5|t-lixIrb zy31_ZRwKJ&qDR8w>g9MU!4RiCe243`Cer-88B$uZc98=U#pjm44BbqkMiT}OE|93Q zCU8$CJr@CZllLwAIQn173|@6HhrydMh1gjsr}Q+O%LottMstb1b01*ev59R)>a*shzl%PBctAOw)5(Kf3D15Ad zkFFia%&oo3!d9g{mA$?G8WdzXHKHW}am_j9ocb<3FJx8YjACHG>9!sqC>)2(1Lr$6 z2a#te)j1O@jtWOUdnOa4k61kjj_E@oU@TQmw0!fE=O-#%h8>YvY zA{ZAWiFWA0Iq=v`-W0F1jhR~IGf3Agi&t$>-6fD)Hf+|$T;%dM{DGdK%;h5Pm#d0N zi9i&@85qmnvT2W_mQE)l!y7py;oFqyyap}(^lugF2%+=+nlRjpb#vP~h=;v}Ol9I! zfyVk>?PN*ri>lQb#=~|!jbUV2G3QOG%!0CI0VpzLR+bsU2yS_2@jLYHIF)lm6-kq+ zJRoj8-zvE6wY4k*q9!hsshrf(G>*Dq>G$BZuTL5-%j#O#+(I4O-IPU1!?v8UorJ?& z0szUvW@jnC=%LT6%wxosb%-fVNhdCFFvSD#xk!ujj7a);H6u+^yY^bWO>imx#2fR{ zvZ>LyB&26pZvx*thiZ0b$4G|Xw$|<>-ta!npVK$=jZ1a!(TgdwmsKX?+knJK;WW5=7-8evhCouPZ6~gQa@a60{%J>`8b3oJp