diff --git a/src/predict.py b/src/predict.py index 9e04255..ec03fe6 100755 --- a/src/predict.py +++ b/src/predict.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from PIL import Image from torchvision import transforms -from utils.data_loading import BasicDataset +from src.utils.dataset import BasicDataset from unet import UNet from utils.utils import plot_img_and_mask diff --git a/src/train.py b/src/train.py index a98655e..7e0b7ab 100644 --- a/src/train.py +++ b/src/train.py @@ -11,8 +11,8 @@ from torch.utils.data import DataLoader, random_split from tqdm import tqdm from evaluate import evaluate +from src.utils.dataset import BasicDataset, CarvanaDataset from unet import UNet -from utils.data_loading import BasicDataset, CarvanaDataset from utils.dice_score import dice_loss dir_img = Path("./data/imgs/") diff --git a/src/unet/__init__.py b/src/unet/__init__.py index 2e9b63b..ed74c60 100644 --- a/src/unet/__init__.py +++ b/src/unet/__init__.py @@ -1 +1 @@ -from .unet_model import UNet +from .model import UNet diff --git a/src/unet/unet_parts.py b/src/unet/blocks.py similarity index 100% rename from src/unet/unet_parts.py rename to src/unet/blocks.py diff --git a/src/unet/unet_model.py b/src/unet/model.py similarity index 97% rename from src/unet/unet_model.py rename to src/unet/model.py index 20c35b5..9fa9c0e 100644 --- a/src/unet/unet_model.py +++ b/src/unet/model.py @@ -1,6 +1,6 @@ """ Full assembly of the parts to form the complete network """ -from .unet_parts import * +from .blocks import * class UNet(nn.Module): diff --git a/src/utils/data_loading.py b/src/utils/dataset.py similarity index 100% rename from src/utils/data_loading.py rename to src/utils/dataset.py