Adding classification test method
This commit is contained in:
parent
9bae9a3a2a
commit
bbe199bb60
|
@ -81,6 +81,98 @@ class ModelTester:
|
|||
# Test main methods
|
||||
# ------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
def classification_test(self, net, test_loader, config, num_votes=100, debug=False):
|
||||
|
||||
############
|
||||
# Initialize
|
||||
############
|
||||
|
||||
# Choose test smoothing parameter (0 for no smothing, 0.99 for big smoothing)
|
||||
softmax = torch.nn.Softmax(1)
|
||||
|
||||
# Number of classes including ignored labels
|
||||
nc_tot = test_loader.dataset.num_classes
|
||||
|
||||
# Number of classes predicted by the model
|
||||
nc_model = config.num_classes
|
||||
|
||||
# Initiate global prediction over test clouds
|
||||
self.test_probs = np.zeros((test_loader.dataset.num_models, nc_model))
|
||||
self.test_counts = np.zeros((test_loader.dataset.num_models, nc_model))
|
||||
|
||||
t = [time.time()]
|
||||
mean_dt = np.zeros(1)
|
||||
last_display = time.time()
|
||||
while np.min(self.test_counts) < num_votes:
|
||||
|
||||
# Run model on all test examples
|
||||
# ******************************
|
||||
|
||||
# Initiate result containers
|
||||
probs = []
|
||||
targets = []
|
||||
obj_inds = []
|
||||
|
||||
# Start validation loop
|
||||
for batch in test_loader:
|
||||
|
||||
# New time
|
||||
t = t[-1:]
|
||||
t += [time.time()]
|
||||
|
||||
if 'cuda' in self.device.type:
|
||||
batch.to(self.device)
|
||||
|
||||
# Forward pass
|
||||
outputs = net(batch, config)
|
||||
|
||||
# Get probs and labels
|
||||
probs += [softmax(outputs).cpu().detach().numpy()]
|
||||
targets += [batch.labels.cpu().numpy()]
|
||||
obj_inds += [batch.model_inds.cpu().numpy()]
|
||||
|
||||
if 'cuda' in self.device.type:
|
||||
torch.cuda.synchronize(self.device)
|
||||
|
||||
# Average timing
|
||||
t += [time.time()]
|
||||
mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))
|
||||
|
||||
# Display
|
||||
if (t[-1] - last_display) > 1.0:
|
||||
last_display = t[-1]
|
||||
message = 'Test vote {:.0f} : {:.1f}% (timings : {:4.2f} {:4.2f})'
|
||||
print(message.format(np.min(self.test_counts),
|
||||
100 * len(obj_inds) / config.validation_size,
|
||||
1000 * (mean_dt[0]),
|
||||
1000 * (mean_dt[1])))
|
||||
# Stack all validation predictions
|
||||
probs = np.vstack(probs)
|
||||
targets = np.hstack(targets)
|
||||
obj_inds = np.hstack(obj_inds)
|
||||
|
||||
if np.any(test_loader.dataset.input_labels[obj_inds] != targets):
|
||||
raise ValueError('wrong object indices')
|
||||
|
||||
# Compute incremental average (predictions are always ordered)
|
||||
self.test_counts[obj_inds] += 1
|
||||
self.test_probs[obj_inds] += (probs - self.test_probs[obj_inds]) / (self.test_counts[obj_inds])
|
||||
|
||||
# Save/Display temporary results
|
||||
# ******************************
|
||||
|
||||
test_labels = np.array(test_loader.dataset.label_values)
|
||||
|
||||
# Compute classification results
|
||||
C1 = fast_confusion(test_loader.dataset.input_labels,
|
||||
np.argmax(self.test_probs, axis=1),
|
||||
test_labels)
|
||||
|
||||
ACC = 100 * np.sum(np.diag(C1)) / (np.sum(C1) + 1e-6)
|
||||
print('Test Accuracy = {:.1f}%'.format(ACC))
|
||||
|
||||
return
|
||||
|
||||
def cloud_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False):
|
||||
"""
|
||||
Test method for cloud segmentation models
|
||||
|
|
Loading…
Reference in a new issue