From bbe199bb60a7e538e22212b5eb4f666660eeae7b Mon Sep 17 00:00:00 2001 From: HuguesTHOMAS Date: Tue, 11 Aug 2020 10:17:21 -0400 Subject: [PATCH] Adding classification test method --- utils/tester.py | 92 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/utils/tester.py b/utils/tester.py index f2be261..674e94c 100644 --- a/utils/tester.py +++ b/utils/tester.py @@ -81,6 +81,98 @@ class ModelTester: # Test main methods # ------------------------------------------------------------------------------------------------------------------ + def classification_test(self, net, test_loader, config, num_votes=100, debug=False): + + ############ + # Initialize + ############ + + # Choose test smoothing parameter (0 for no smothing, 0.99 for big smoothing) + softmax = torch.nn.Softmax(1) + + # Number of classes including ignored labels + nc_tot = test_loader.dataset.num_classes + + # Number of classes predicted by the model + nc_model = config.num_classes + + # Initiate global prediction over test clouds + self.test_probs = np.zeros((test_loader.dataset.num_models, nc_model)) + self.test_counts = np.zeros((test_loader.dataset.num_models, nc_model)) + + t = [time.time()] + mean_dt = np.zeros(1) + last_display = time.time() + while np.min(self.test_counts) < num_votes: + + # Run model on all test examples + # ****************************** + + # Initiate result containers + probs = [] + targets = [] + obj_inds = [] + + # Start validation loop + for batch in test_loader: + + # New time + t = t[-1:] + t += [time.time()] + + if 'cuda' in self.device.type: + batch.to(self.device) + + # Forward pass + outputs = net(batch, config) + + # Get probs and labels + probs += [softmax(outputs).cpu().detach().numpy()] + targets += [batch.labels.cpu().numpy()] + obj_inds += [batch.model_inds.cpu().numpy()] + + if 'cuda' in self.device.type: + torch.cuda.synchronize(self.device) + + # Average timing + t += [time.time()] + mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1])) + + # Display + if (t[-1] - last_display) > 1.0: + last_display = t[-1] + message = 'Test vote {:.0f} : {:.1f}% (timings : {:4.2f} {:4.2f})' + print(message.format(np.min(self.test_counts), + 100 * len(obj_inds) / config.validation_size, + 1000 * (mean_dt[0]), + 1000 * (mean_dt[1]))) + # Stack all validation predictions + probs = np.vstack(probs) + targets = np.hstack(targets) + obj_inds = np.hstack(obj_inds) + + if np.any(test_loader.dataset.input_labels[obj_inds] != targets): + raise ValueError('wrong object indices') + + # Compute incremental average (predictions are always ordered) + self.test_counts[obj_inds] += 1 + self.test_probs[obj_inds] += (probs - self.test_probs[obj_inds]) / (self.test_counts[obj_inds]) + + # Save/Display temporary results + # ****************************** + + test_labels = np.array(test_loader.dataset.label_values) + + # Compute classification results + C1 = fast_confusion(test_loader.dataset.input_labels, + np.argmax(self.test_probs, axis=1), + test_labels) + + ACC = 100 * np.sum(np.diag(C1)) / (np.sum(C1) + 1e-6) + print('Test Accuracy = {:.1f}%'.format(ACC)) + + return + def cloud_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False): """ Test method for cloud segmentation models