Corrections

This commit is contained in:
HuguesTHOMAS 2020-04-29 14:39:39 -04:00
parent 342abc44d3
commit 99f0c8bb20

View file

@ -371,7 +371,7 @@ class ModelTester:
return
def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False):
def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=True):
"""
Test method for slam segmentation models
"""
@ -502,12 +502,13 @@ class ModelTester:
if test_loader.dataset.set == 'validation':
# Insert false columns for ignored labels
frame_probs_uint8_bis = frame_probs_uint8.copy()
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
if label_value in test_loader.dataset.ignored_labels:
frame_probs_uint8 = np.insert(frame_probs_uint8, l_ind, 0, axis=1)
frame_probs_uint8_bis = np.insert(frame_probs_uint8_bis, l_ind, 0, axis=1)
# Predicted labels
frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8,
frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8_bis,
axis=1)].astype(np.int32)
# Save some of the frame pots
@ -528,6 +529,15 @@ class ModelTester:
[frame_points[:, :3], frame_labels, frame_preds],
['x', 'y', 'z', 'gt', 'pre'])
# Also Save lbl probabilities
probpath = join(test_path, folder, filename[:-4] + '_probs.ply')
lbl_names = [test_loader.dataset.label_to_names[l]
for l in test_loader.dataset.label_values
if l not in test_loader.dataset.ignored_labels]
write_ply(probpath,
[frame_points[:, :3], frame_probs_uint8],
['x', 'y', 'z'] + lbl_names)
# keep frame preds in memory
all_f_preds[s_ind][f_ind] = frame_preds
all_f_labels[s_ind][f_ind] = frame_labels
@ -575,8 +585,8 @@ class ModelTester:
last_display = t[-1]
message = 'e{:03d}-i{:04d} => {:.1f}% (timings : {:4.2f} {:4.2f} {:4.2f}) / pots {:d} => {:.1f}%'
min_pot = int(torch.floor(torch.min(test_loader.dataset.potentials)))
pot_num = torch.sum(test_loader.dataset.potentials > min_pot).type(torch.int32).item()
current_num = pot_num + (i0 + 1 - config.validation_size) * config.val_batch_num
pot_num = torch.sum(test_loader.dataset.potentials > min_pot + 0.5).type(torch.int32).item()
current_num = pot_num + (i + 1 - config.validation_size) * config.val_batch_num
print(message.format(test_epoch, i,
100 * i / config.validation_size,
1000 * (mean_dt[0]),