From 4cebd543d796ce3c324cb9f39486deab97133bfe Mon Sep 17 00:00:00 2001 From: HuguesTHOMAS Date: Mon, 11 Apr 2022 09:21:19 -0400 Subject: [PATCH] Test set: correct ignored labels column --- utils/tester.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/utils/tester.py b/utils/tester.py index c51aaec..f2f0357 100644 --- a/utils/tester.py +++ b/utils/tester.py @@ -369,6 +369,11 @@ class ModelTester: probs = self.test_probs[i][test_loader.dataset.test_proj[i], :] proj_probs += [probs] + # Insert false columns for ignored labels + for l_ind, label_value in enumerate(test_loader.dataset.label_values): + if label_value in test_loader.dataset.ignored_labels: + proj_probs[i] = np.insert(proj_probs[i], l_ind, 0, axis=1) + t2 = time.time() print('Done in {:.1f} s\n'.format(t2 - t1)) @@ -379,11 +384,6 @@ class ModelTester: Confs = [] for i, file_path in enumerate(test_loader.dataset.files): - # Insert false columns for ignored labels - for l_ind, label_value in enumerate(test_loader.dataset.label_values): - if label_value in test_loader.dataset.ignored_labels: - proj_probs[i] = np.insert(proj_probs[i], l_ind, 0, axis=1) - # Get the predicted labels preds = test_loader.dataset.label_values[np.argmax(proj_probs[i], axis=1)].astype(np.int32)