From 5cd2a3b0b785d3f40af6b06bc5a559a441ad7549 Mon Sep 17 00:00:00 2001 From: milesial Date: Mon, 16 Aug 2021 06:08:09 +0200 Subject: [PATCH] Restructure Former-commit-id: beadb49b75ea79a3c0f95df589f64a8274419c5b --- README.md | 2 +- evaluate.py | 2 +- predict.py | 2 +- download_data.sh => scripts/download_data.sh | 0 train.py | 4 ++-- utils/__init__.py | 0 data_loading.py => utils/data_loading.py | 0 dice_score.py => utils/dice_score.py | 0 utils.py => utils/utils.py | 0 9 files changed, 5 insertions(+), 5 deletions(-) rename download_data.sh => scripts/download_data.sh (100%) create mode 100644 utils/__init__.py rename data_loading.py => utils/data_loading.py (100%) rename dice_score.py => utils/dice_score.py (100%) rename utils.py => utils/utils.py (100%) diff --git a/README.md b/README.md index 99549cc..460b5cc 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/c You can also download it using your Kaggle API key with: ```shell script -bash download_data.sh +bash scripts/download_data.sh ``` ## Notes on memory diff --git a/evaluate.py b/evaluate.py index 053d726..88b337b 100644 --- a/evaluate.py +++ b/evaluate.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F from tqdm import tqdm -from dice_score import multiclass_dice_coeff +from utils.dice_score import multiclass_dice_coeff def evaluate(net, dataloader, device): diff --git a/predict.py b/predict.py index 348b27c..296fdb8 100755 --- a/predict.py +++ b/predict.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from PIL import Image from torchvision import transforms -from data_loading import BasicDataset +from utils.data_loading import BasicDataset from unet import UNet from utils import plot_img_and_mask diff --git a/download_data.sh b/scripts/download_data.sh similarity index 100% rename from download_data.sh rename to scripts/download_data.sh diff --git a/train.py b/train.py index d10df0b..7014da1 100644 --- a/train.py +++ b/train.py @@ -11,8 +11,8 @@ from torch import optim from torch.utils.data import DataLoader, random_split from tqdm import tqdm -from data_loading import BasicDataset, CarvanaDataset -from dice_score import dice_loss +from utils.data_loading import BasicDataset, CarvanaDataset +from utils.dice_score import dice_loss from evaluate import evaluate from unet import UNet diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_loading.py b/utils/data_loading.py similarity index 100% rename from data_loading.py rename to utils/data_loading.py diff --git a/dice_score.py b/utils/dice_score.py similarity index 100% rename from dice_score.py rename to utils/dice_score.py diff --git a/utils.py b/utils/utils.py similarity index 100% rename from utils.py rename to utils/utils.py