Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions train/IOUEval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def reset(self):
self.per_class_acc = np.zeros(self.nClasses, dtype=np.float32)
self.per_class_iu = np.zeros(self.nClasses, dtype=np.float32)
self.mIOU = 0
self.batchCount = 1
self.batchCount = 0

def fast_hist(self, a, b):
k = (a >= 0) & (a < self.nClasses)
Expand All @@ -24,11 +24,21 @@ def compute_hist(self, predict, gth):
return hist

def addBatch(self, predict, gth):
predict = predict.cpu().numpy().flatten()
gth = gth.cpu().numpy().flatten()
if isinstance(predict, np.ndarray):
predict = predict.flatten()
gth = gth.flatten()
elif isinstance(predict, torch.Tensor):
predict = predict.cpu().numpy().flatten()
gth = gth.cpu().numpy().flatten()

epsilon = 0.00000001
if self.batchCount == 0:
self.hist = self.compute_hist(predict, gth)
else:
self.hist += self.compute_hist(predict, gth)
hist = self.compute_hist(predict, gth)
# hist(0) : TP + FN
# hist(1) : TP + FP
overall_acc = np.diag(hist).sum() / (hist.sum() + epsilon)
per_class_acc = np.diag(hist) / (hist.sum(1) + epsilon)
per_class_iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)
Expand All @@ -39,11 +49,13 @@ def addBatch(self, predict, gth):
self.per_class_iu += per_class_iu
self.mIOU += mIou
self.batchCount += 1
return hist

def getMetric(self):
overall_acc = self.overall_acc/self.batchCount
per_class_acc = self.per_class_acc / self.batchCount
per_class_iu = self.per_class_iu / self.batchCount
mIOU = self.mIOU / self.batchCount

return overall_acc, per_class_acc, per_class_iu, mIOU
epsilon = 0.00000001
overall_acc = np.diag(self.hist).sum() / (self.hist.sum() + epsilon)
per_class_acc = np.diag(self.hist) / (self.hist.sum(1) + epsilon)
per_class_iu = np.diag(self.hist) / (self.hist.sum(1) + self.hist.sum(0) - np.diag(self.hist) + epsilon)
mIOU = np.nanmean(per_class_iu)
return overall_acc, per_class_acc, per_class_iu, mIOU