diff --git a/src/evaluate.py b/src/evaluate.py index c67a60b..bdaeb78 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F from tqdm import tqdm -from utils.dice_score import multiclass_dice_coeff, dice_coeff +from src.utils.dice import multiclass_dice_coeff, dice_coeff def evaluate(net, dataloader, device): diff --git a/src/train.py b/src/train.py index 3e559a7..690273e 100644 --- a/src/train.py +++ b/src/train.py @@ -13,7 +13,7 @@ from tqdm import tqdm from evaluate import evaluate from src.utils.dataset import BasicDataset, CarvanaDataset from unet import UNet -from utils.dice_score import dice_loss +from src.utils.dice import dice_loss dir_img = Path("./data/imgs/") dir_mask = Path("./data/masks/") diff --git a/src/utils/dice_score.py b/src/utils/dice.py similarity index 100% rename from src/utils/dice_score.py rename to src/utils/dice.py