Adding classification test method
This commit is contained in:
parent
9bae9a3a2a
commit
bbe199bb60
|
@ -81,6 +81,98 @@ class ModelTester:
|
||||||
# Test main methods
|
# Test main methods
|
||||||
# ------------------------------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def classification_test(self, net, test_loader, config, num_votes=100, debug=False):
|
||||||
|
|
||||||
|
############
|
||||||
|
# Initialize
|
||||||
|
############
|
||||||
|
|
||||||
|
# Choose test smoothing parameter (0 for no smothing, 0.99 for big smoothing)
|
||||||
|
softmax = torch.nn.Softmax(1)
|
||||||
|
|
||||||
|
# Number of classes including ignored labels
|
||||||
|
nc_tot = test_loader.dataset.num_classes
|
||||||
|
|
||||||
|
# Number of classes predicted by the model
|
||||||
|
nc_model = config.num_classes
|
||||||
|
|
||||||
|
# Initiate global prediction over test clouds
|
||||||
|
self.test_probs = np.zeros((test_loader.dataset.num_models, nc_model))
|
||||||
|
self.test_counts = np.zeros((test_loader.dataset.num_models, nc_model))
|
||||||
|
|
||||||
|
t = [time.time()]
|
||||||
|
mean_dt = np.zeros(1)
|
||||||
|
last_display = time.time()
|
||||||
|
while np.min(self.test_counts) < num_votes:
|
||||||
|
|
||||||
|
# Run model on all test examples
|
||||||
|
# ******************************
|
||||||
|
|
||||||
|
# Initiate result containers
|
||||||
|
probs = []
|
||||||
|
targets = []
|
||||||
|
obj_inds = []
|
||||||
|
|
||||||
|
# Start validation loop
|
||||||
|
for batch in test_loader:
|
||||||
|
|
||||||
|
# New time
|
||||||
|
t = t[-1:]
|
||||||
|
t += [time.time()]
|
||||||
|
|
||||||
|
if 'cuda' in self.device.type:
|
||||||
|
batch.to(self.device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
outputs = net(batch, config)
|
||||||
|
|
||||||
|
# Get probs and labels
|
||||||
|
probs += [softmax(outputs).cpu().detach().numpy()]
|
||||||
|
targets += [batch.labels.cpu().numpy()]
|
||||||
|
obj_inds += [batch.model_inds.cpu().numpy()]
|
||||||
|
|
||||||
|
if 'cuda' in self.device.type:
|
||||||
|
torch.cuda.synchronize(self.device)
|
||||||
|
|
||||||
|
# Average timing
|
||||||
|
t += [time.time()]
|
||||||
|
mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))
|
||||||
|
|
||||||
|
# Display
|
||||||
|
if (t[-1] - last_display) > 1.0:
|
||||||
|
last_display = t[-1]
|
||||||
|
message = 'Test vote {:.0f} : {:.1f}% (timings : {:4.2f} {:4.2f})'
|
||||||
|
print(message.format(np.min(self.test_counts),
|
||||||
|
100 * len(obj_inds) / config.validation_size,
|
||||||
|
1000 * (mean_dt[0]),
|
||||||
|
1000 * (mean_dt[1])))
|
||||||
|
# Stack all validation predictions
|
||||||
|
probs = np.vstack(probs)
|
||||||
|
targets = np.hstack(targets)
|
||||||
|
obj_inds = np.hstack(obj_inds)
|
||||||
|
|
||||||
|
if np.any(test_loader.dataset.input_labels[obj_inds] != targets):
|
||||||
|
raise ValueError('wrong object indices')
|
||||||
|
|
||||||
|
# Compute incremental average (predictions are always ordered)
|
||||||
|
self.test_counts[obj_inds] += 1
|
||||||
|
self.test_probs[obj_inds] += (probs - self.test_probs[obj_inds]) / (self.test_counts[obj_inds])
|
||||||
|
|
||||||
|
# Save/Display temporary results
|
||||||
|
# ******************************
|
||||||
|
|
||||||
|
test_labels = np.array(test_loader.dataset.label_values)
|
||||||
|
|
||||||
|
# Compute classification results
|
||||||
|
C1 = fast_confusion(test_loader.dataset.input_labels,
|
||||||
|
np.argmax(self.test_probs, axis=1),
|
||||||
|
test_labels)
|
||||||
|
|
||||||
|
ACC = 100 * np.sum(np.diag(C1)) / (np.sum(C1) + 1e-6)
|
||||||
|
print('Test Accuracy = {:.1f}%'.format(ACC))
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
def cloud_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False):
|
def cloud_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False):
|
||||||
"""
|
"""
|
||||||
Test method for cloud segmentation models
|
Test method for cloud segmentation models
|
||||||
|
|
Loading…
Reference in a new issue