diff --git a/utils/tester.py b/utils/tester.py index 333625b..f2be261 100644 --- a/utils/tester.py +++ b/utils/tester.py @@ -371,7 +371,7 @@ class ModelTester: return - def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False): + def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=True): """ Test method for slam segmentation models """ @@ -502,12 +502,13 @@ class ModelTester: if test_loader.dataset.set == 'validation': # Insert false columns for ignored labels + frame_probs_uint8_bis = frame_probs_uint8.copy() for l_ind, label_value in enumerate(test_loader.dataset.label_values): if label_value in test_loader.dataset.ignored_labels: - frame_probs_uint8 = np.insert(frame_probs_uint8, l_ind, 0, axis=1) + frame_probs_uint8_bis = np.insert(frame_probs_uint8_bis, l_ind, 0, axis=1) # Predicted labels - frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8, + frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8_bis, axis=1)].astype(np.int32) # Save some of the frame pots @@ -528,6 +529,15 @@ class ModelTester: [frame_points[:, :3], frame_labels, frame_preds], ['x', 'y', 'z', 'gt', 'pre']) + # Also Save lbl probabilities + probpath = join(test_path, folder, filename[:-4] + '_probs.ply') + lbl_names = [test_loader.dataset.label_to_names[l] + for l in test_loader.dataset.label_values + if l not in test_loader.dataset.ignored_labels] + write_ply(probpath, + [frame_points[:, :3], frame_probs_uint8], + ['x', 'y', 'z'] + lbl_names) + # keep frame preds in memory all_f_preds[s_ind][f_ind] = frame_preds all_f_labels[s_ind][f_ind] = frame_labels @@ -575,8 +585,8 @@ class ModelTester: last_display = t[-1] message = 'e{:03d}-i{:04d} => {:.1f}% (timings : {:4.2f} {:4.2f} {:4.2f}) / pots {:d} => {:.1f}%' min_pot = int(torch.floor(torch.min(test_loader.dataset.potentials))) - pot_num = torch.sum(test_loader.dataset.potentials > min_pot).type(torch.int32).item() - current_num = pot_num + (i0 + 1 - config.validation_size) * config.val_batch_num + pot_num = torch.sum(test_loader.dataset.potentials > min_pot + 0.5).type(torch.int32).item() + current_num = pot_num + (i + 1 - config.validation_size) * config.val_batch_num print(message.format(test_epoch, i, 100 * i / config.validation_size, 1000 * (mean_dt[0]),