Test set: correct ignored labels column
This commit is contained in:
parent
7255680ff0
commit
4cebd543d7
|
@ -369,6 +369,11 @@ class ModelTester:
|
|||
probs = self.test_probs[i][test_loader.dataset.test_proj[i], :]
|
||||
proj_probs += [probs]
|
||||
|
||||
# Insert false columns for ignored labels
|
||||
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
|
||||
if label_value in test_loader.dataset.ignored_labels:
|
||||
proj_probs[i] = np.insert(proj_probs[i], l_ind, 0, axis=1)
|
||||
|
||||
t2 = time.time()
|
||||
print('Done in {:.1f} s\n'.format(t2 - t1))
|
||||
|
||||
|
@ -379,11 +384,6 @@ class ModelTester:
|
|||
Confs = []
|
||||
for i, file_path in enumerate(test_loader.dataset.files):
|
||||
|
||||
# Insert false columns for ignored labels
|
||||
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
|
||||
if label_value in test_loader.dataset.ignored_labels:
|
||||
proj_probs[i] = np.insert(proj_probs[i], l_ind, 0, axis=1)
|
||||
|
||||
# Get the predicted labels
|
||||
preds = test_loader.dataset.label_values[np.argmax(proj_probs[i], axis=1)].astype(np.int32)
|
||||
|
||||
|
|
Loading…
Reference in a new issue