Test set: correct ignored labels column
This commit is contained in:
parent
7255680ff0
commit
4cebd543d7
|
@ -369,6 +369,11 @@ class ModelTester:
|
||||||
probs = self.test_probs[i][test_loader.dataset.test_proj[i], :]
|
probs = self.test_probs[i][test_loader.dataset.test_proj[i], :]
|
||||||
proj_probs += [probs]
|
proj_probs += [probs]
|
||||||
|
|
||||||
|
# Insert false columns for ignored labels
|
||||||
|
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
|
||||||
|
if label_value in test_loader.dataset.ignored_labels:
|
||||||
|
proj_probs[i] = np.insert(proj_probs[i], l_ind, 0, axis=1)
|
||||||
|
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
print('Done in {:.1f} s\n'.format(t2 - t1))
|
print('Done in {:.1f} s\n'.format(t2 - t1))
|
||||||
|
|
||||||
|
@ -379,11 +384,6 @@ class ModelTester:
|
||||||
Confs = []
|
Confs = []
|
||||||
for i, file_path in enumerate(test_loader.dataset.files):
|
for i, file_path in enumerate(test_loader.dataset.files):
|
||||||
|
|
||||||
# Insert false columns for ignored labels
|
|
||||||
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
|
|
||||||
if label_value in test_loader.dataset.ignored_labels:
|
|
||||||
proj_probs[i] = np.insert(proj_probs[i], l_ind, 0, axis=1)
|
|
||||||
|
|
||||||
# Get the predicted labels
|
# Get the predicted labels
|
||||||
preds = test_loader.dataset.label_values[np.argmax(proj_probs[i], axis=1)].astype(np.int32)
|
preds = test_loader.dataset.label_values[np.argmax(proj_probs[i], axis=1)].astype(np.int32)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue