Corrections
This commit is contained in:
parent
342abc44d3
commit
99f0c8bb20
|
@ -371,7 +371,7 @@ class ModelTester:
|
|||
|
||||
return
|
||||
|
||||
def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False):
|
||||
def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=True):
|
||||
"""
|
||||
Test method for slam segmentation models
|
||||
"""
|
||||
|
@ -502,12 +502,13 @@ class ModelTester:
|
|||
if test_loader.dataset.set == 'validation':
|
||||
|
||||
# Insert false columns for ignored labels
|
||||
frame_probs_uint8_bis = frame_probs_uint8.copy()
|
||||
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
|
||||
if label_value in test_loader.dataset.ignored_labels:
|
||||
frame_probs_uint8 = np.insert(frame_probs_uint8, l_ind, 0, axis=1)
|
||||
frame_probs_uint8_bis = np.insert(frame_probs_uint8_bis, l_ind, 0, axis=1)
|
||||
|
||||
# Predicted labels
|
||||
frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8,
|
||||
frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8_bis,
|
||||
axis=1)].astype(np.int32)
|
||||
|
||||
# Save some of the frame pots
|
||||
|
@ -528,6 +529,15 @@ class ModelTester:
|
|||
[frame_points[:, :3], frame_labels, frame_preds],
|
||||
['x', 'y', 'z', 'gt', 'pre'])
|
||||
|
||||
# Also Save lbl probabilities
|
||||
probpath = join(test_path, folder, filename[:-4] + '_probs.ply')
|
||||
lbl_names = [test_loader.dataset.label_to_names[l]
|
||||
for l in test_loader.dataset.label_values
|
||||
if l not in test_loader.dataset.ignored_labels]
|
||||
write_ply(probpath,
|
||||
[frame_points[:, :3], frame_probs_uint8],
|
||||
['x', 'y', 'z'] + lbl_names)
|
||||
|
||||
# keep frame preds in memory
|
||||
all_f_preds[s_ind][f_ind] = frame_preds
|
||||
all_f_labels[s_ind][f_ind] = frame_labels
|
||||
|
@ -575,8 +585,8 @@ class ModelTester:
|
|||
last_display = t[-1]
|
||||
message = 'e{:03d}-i{:04d} => {:.1f}% (timings : {:4.2f} {:4.2f} {:4.2f}) / pots {:d} => {:.1f}%'
|
||||
min_pot = int(torch.floor(torch.min(test_loader.dataset.potentials)))
|
||||
pot_num = torch.sum(test_loader.dataset.potentials > min_pot).type(torch.int32).item()
|
||||
current_num = pot_num + (i0 + 1 - config.validation_size) * config.val_batch_num
|
||||
pot_num = torch.sum(test_loader.dataset.potentials > min_pot + 0.5).type(torch.int32).item()
|
||||
current_num = pot_num + (i + 1 - config.validation_size) * config.val_batch_num
|
||||
print(message.format(test_epoch, i,
|
||||
100 * i / config.validation_size,
|
||||
1000 * (mean_dt[0]),
|
||||
|
|
Loading…
Reference in a new issue