Test set: correct ignored labels column

This commit is contained in:
HuguesTHOMAS 2022-04-11 09:21:19 -04:00
parent 7255680ff0
commit 4cebd543d7

View file

@ -369,6 +369,11 @@ class ModelTester:
probs = self.test_probs[i][test_loader.dataset.test_proj[i], :]
proj_probs += [probs]
# Insert false columns for ignored labels
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
if label_value in test_loader.dataset.ignored_labels:
proj_probs[i] = np.insert(proj_probs[i], l_ind, 0, axis=1)
t2 = time.time()
print('Done in {:.1f} s\n'.format(t2 - t1))
@ -379,11 +384,6 @@ class ModelTester:
Confs = []
for i, file_path in enumerate(test_loader.dataset.files):
# Insert false columns for ignored labels
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
if label_value in test_loader.dataset.ignored_labels:
proj_probs[i] = np.insert(proj_probs[i], l_ind, 0, axis=1)
# Get the predicted labels
preds = test_loader.dataset.label_values[np.argmax(proj_probs[i], axis=1)].astype(np.int32)