mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
feat: reduce the number of parameters in the net
Former-commit-id: 862569b6d284ec8235586b161d8c7055c006f5d8 [formerly f2e672d780df12a398e851f375a238c2d394a3cd] Former-commit-id: 740b1129a627c488537bb0d0dc7ff73b66fde813
This commit is contained in:
parent
9fe76d8c61
commit
dac6237906
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -1,6 +1,10 @@
|
|||
.venv/
|
||||
.mypy_cache/
|
||||
__pycache__/
|
||||
|
||||
wandb/
|
||||
images/
|
||||
|
||||
*.pth
|
||||
*.png
|
||||
*.jpg
|
||||
|
|
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
|
@ -12,9 +12,9 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--input",
|
||||
"SM.png",
|
||||
"images/SM.png",
|
||||
"--output",
|
||||
"test.png",
|
||||
"output.png",
|
||||
],
|
||||
"justMyCode": true
|
||||
}
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
c6d08aa612451072cfe32a3ee086d08342ed9dd9
|
20
src/train.py
20
src/train.py
|
@ -18,7 +18,7 @@ from utils.paste import RandomPaste
|
|||
|
||||
CHECKPOINT_DIR = Path("./checkpoints/")
|
||||
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017")
|
||||
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017/")
|
||||
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/")
|
||||
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
|
||||
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
|
||||
|
||||
|
@ -41,7 +41,7 @@ def get_args():
|
|||
dest="batch_size",
|
||||
metavar="B",
|
||||
type=int,
|
||||
default=16,
|
||||
default=70,
|
||||
help="Batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -92,11 +92,14 @@ def main():
|
|||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
# 0. Create network
|
||||
net = UNet(n_channels=3, n_classes=args.classes)
|
||||
features = [16, 32, 64, 128]
|
||||
net = UNet(n_channels=args.n_channels, n_classes=args.classes, features=features)
|
||||
logging.info(
|
||||
f"""Network:
|
||||
input channels: {net.n_channels}
|
||||
output channels: {net.n_classes}
|
||||
nb parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad)}
|
||||
features: {features}
|
||||
"""
|
||||
)
|
||||
|
||||
|
@ -138,7 +141,7 @@ def main():
|
|||
ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
|
||||
|
||||
# 3. Create data loaders
|
||||
loader_args = dict(batch_size=args.batch_size, num_workers=6, pin_memory=True)
|
||||
loader_args = dict(batch_size=args.batch_size, num_workers=8, pin_memory=True)
|
||||
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
|
||||
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
|
||||
|
||||
|
@ -159,9 +162,9 @@ def main():
|
|||
),
|
||||
)
|
||||
wandb.watch(net, log_freq=100)
|
||||
# artifact = wandb.Artifact("model", type="model")
|
||||
# artifact.add_file("model.pth")
|
||||
# run.log_artifact(artifact)
|
||||
artifact_model = wandb.Artifact("model", type="model")
|
||||
artifact_model.add_file("model.pth")
|
||||
run.log_artifact(artifact_model)
|
||||
|
||||
logging.info(
|
||||
f"""Starting training:
|
||||
|
@ -228,8 +231,7 @@ def main():
|
|||
print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}")
|
||||
|
||||
# save weights when epoch end
|
||||
# torch.save(net.state_dict(), "model.pth")
|
||||
# run.log_artifact(artifact)
|
||||
torch.save(net.state_dict(), "model.pth")
|
||||
logging.info(f"model saved!")
|
||||
|
||||
run.finish()
|
||||
|
|
Loading…
Reference in a new issue