Corrections
This commit is contained in:
parent
342abc44d3
commit
99f0c8bb20
|
@ -371,7 +371,7 @@ class ModelTester:
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False):
|
def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=True):
|
||||||
"""
|
"""
|
||||||
Test method for slam segmentation models
|
Test method for slam segmentation models
|
||||||
"""
|
"""
|
||||||
|
@ -502,12 +502,13 @@ class ModelTester:
|
||||||
if test_loader.dataset.set == 'validation':
|
if test_loader.dataset.set == 'validation':
|
||||||
|
|
||||||
# Insert false columns for ignored labels
|
# Insert false columns for ignored labels
|
||||||
|
frame_probs_uint8_bis = frame_probs_uint8.copy()
|
||||||
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
|
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
|
||||||
if label_value in test_loader.dataset.ignored_labels:
|
if label_value in test_loader.dataset.ignored_labels:
|
||||||
frame_probs_uint8 = np.insert(frame_probs_uint8, l_ind, 0, axis=1)
|
frame_probs_uint8_bis = np.insert(frame_probs_uint8_bis, l_ind, 0, axis=1)
|
||||||
|
|
||||||
# Predicted labels
|
# Predicted labels
|
||||||
frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8,
|
frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8_bis,
|
||||||
axis=1)].astype(np.int32)
|
axis=1)].astype(np.int32)
|
||||||
|
|
||||||
# Save some of the frame pots
|
# Save some of the frame pots
|
||||||
|
@ -528,6 +529,15 @@ class ModelTester:
|
||||||
[frame_points[:, :3], frame_labels, frame_preds],
|
[frame_points[:, :3], frame_labels, frame_preds],
|
||||||
['x', 'y', 'z', 'gt', 'pre'])
|
['x', 'y', 'z', 'gt', 'pre'])
|
||||||
|
|
||||||
|
# Also Save lbl probabilities
|
||||||
|
probpath = join(test_path, folder, filename[:-4] + '_probs.ply')
|
||||||
|
lbl_names = [test_loader.dataset.label_to_names[l]
|
||||||
|
for l in test_loader.dataset.label_values
|
||||||
|
if l not in test_loader.dataset.ignored_labels]
|
||||||
|
write_ply(probpath,
|
||||||
|
[frame_points[:, :3], frame_probs_uint8],
|
||||||
|
['x', 'y', 'z'] + lbl_names)
|
||||||
|
|
||||||
# keep frame preds in memory
|
# keep frame preds in memory
|
||||||
all_f_preds[s_ind][f_ind] = frame_preds
|
all_f_preds[s_ind][f_ind] = frame_preds
|
||||||
all_f_labels[s_ind][f_ind] = frame_labels
|
all_f_labels[s_ind][f_ind] = frame_labels
|
||||||
|
@ -575,8 +585,8 @@ class ModelTester:
|
||||||
last_display = t[-1]
|
last_display = t[-1]
|
||||||
message = 'e{:03d}-i{:04d} => {:.1f}% (timings : {:4.2f} {:4.2f} {:4.2f}) / pots {:d} => {:.1f}%'
|
message = 'e{:03d}-i{:04d} => {:.1f}% (timings : {:4.2f} {:4.2f} {:4.2f}) / pots {:d} => {:.1f}%'
|
||||||
min_pot = int(torch.floor(torch.min(test_loader.dataset.potentials)))
|
min_pot = int(torch.floor(torch.min(test_loader.dataset.potentials)))
|
||||||
pot_num = torch.sum(test_loader.dataset.potentials > min_pot).type(torch.int32).item()
|
pot_num = torch.sum(test_loader.dataset.potentials > min_pot + 0.5).type(torch.int32).item()
|
||||||
current_num = pot_num + (i0 + 1 - config.validation_size) * config.val_batch_num
|
current_num = pot_num + (i + 1 - config.validation_size) * config.val_batch_num
|
||||||
print(message.format(test_epoch, i,
|
print(message.format(test_epoch, i,
|
||||||
100 * i / config.validation_size,
|
100 * i / config.validation_size,
|
||||||
1000 * (mean_dt[0]),
|
1000 * (mean_dt[0]),
|
||||||
|
|
Loading…
Reference in a new issue