// TIF_E41201615/lib/utils/detection/nms.dart

import 'dart:math';
/// Runs class-aware greedy non-maximum suppression over raw detector output.
///
/// [rawOutput] is indexed as `rawOutput[row][anchor]`. Rows 0–3 hold each
/// anchor's box as `[centerX, centerY, width, height]`, and the following
/// rows hold one score per class. The number of anchors is taken from
/// `rawOutput[0].length`. The number of class rows defaults to
/// `rawOutput.length - 4`. Pass [numClasses] to use only the first
/// [numClasses] score rows, for example when non-class rows follow them.
///
/// For every anchor the highest-scoring class is selected. On ties the lowest
/// class id wins. Anchors whose best score is strictly greater than
/// [confidenceThreshold] become candidates. Candidates are visited in order
/// of decreasing score, and equal scores keep anchor order. Each kept box
/// suppresses every remaining candidate of the *same* class whose IoU with it
/// is strictly greater than [iouThreshold].
///
/// Returns `(classes, boxes, scores)` for the kept detections, ordered by
/// decreasing score:
/// - `classes` are zero-based offsets from row 4.
/// - `boxes` use the same `[centerX, centerY, width, height]` layout as the
///   input rows.
///
/// Throws an [ArgumentError] if [numClasses] is negative or needs more rows
/// than [rawOutput] provides.
(List<int>, List<List<double>>, List<double>) nms(
  List<List<double>> rawOutput, {
  double confidenceThreshold = 0.7,
  double iouThreshold = 0.4,
  int? numClasses,
}) {
  const boxRows = 4;
  if (numClasses != null &&
      (numClasses < 0 || rawOutput.length < boxRows + numClasses)) {
    throw ArgumentError.value(numClasses, 'numClasses',
        'must be between 0 and rawOutput.length - $boxRows');
  }
  final classCount = numClasses ?? rawOutput.length - boxRows;
  if (classCount <= 0) {
    // No class-score rows (or not even a complete box): nothing to detect.
    return (<int>[], <List<double>>[], <double>[]);
  }
  final anchorCount = rawOutput[0].length;

  // 1. Per-anchor argmax over class scores, keeping only confident anchors.
  final candidateAnchors = <int>[];
  final candidateClasses = <int>[];
  final candidateScores = <double>[];
  for (var anchor = 0; anchor < anchorCount; anchor++) {
    // Seeding with -infinity (not 0) means a negative threshold can never
    // produce the placeholder class -1.
    var bestScore = double.negativeInfinity;
    var bestClass = -1;
    for (var cls = 0; cls < classCount; cls++) {
      final score = rawOutput[boxRows + cls][anchor];
      if (score > bestScore) {
        bestScore = score;
        bestClass = cls;
      }
    }
    if (bestClass >= 0 && bestScore > confidenceThreshold) {
      candidateAnchors.add(anchor);
      candidateClasses.add(bestClass);
      candidateScores.add(bestScore);
    }
  }

  // 2. Gather candidate boxes. The xyxy form is computed once per box here
  //    instead of on every IoU comparison.
  final boxesXywh = [
    for (final anchor in candidateAnchors)
      [for (var row = 0; row < boxRows; row++) rawOutput[row][anchor]],
  ];
  final boxesXyxy = [for (final box in boxesXywh) xywh2xyxy(box)];

  // 3. Sort candidate *indices* by descending score. Looking scores up by
  //    value (indexOf) would collapse candidates that share a score. The
  //    index tie-break makes the order deterministic, because List.sort is
  //    not guaranteed to be stable.
  final order = List<int>.generate(candidateScores.length, (i) => i)
    ..sort((a, b) {
      final byScore = candidateScores[b].compareTo(candidateScores[a]);
      return byScore != 0 ? byScore : a.compareTo(b);
    });

  // 4. Greedy suppression: keep the best remaining candidate, then discard
  //    same-class candidates that overlap it too much.
  final finalClasses = <int>[];
  final finalBboxes = <List<double>>[];
  final finalScores = <double>[];
  final suppressed = List<bool>.filled(order.length, false);
  for (var i = 0; i < order.length; i++) {
    if (suppressed[i]) continue;
    final kept = order[i];
    finalClasses.add(candidateClasses[kept]);
    finalBboxes.add(boxesXywh[kept]);
    finalScores.add(candidateScores[kept]);
    for (var j = i + 1; j < order.length; j++) {
      if (suppressed[j]) continue;
      final other = order[j];
      if (candidateClasses[other] == candidateClasses[kept] &&
          computeIou(boxesXyxy[kept], boxesXyxy[other]) > iouThreshold) {
        suppressed[j] = true;
      }
    }
  }
  return (finalClasses, finalBboxes, finalScores);
}
/// Converts a center-format box `[cx, cy, w, h]` into corner format
/// `[x1, y1, x2, y2]`.
///
/// `(x1, y1)` is the minimum corner and `(x2, y2)` is the maximum corner.
/// Only the first four entries of [bbox] are read, and a new list is
/// returned.
List<double> xywh2xyxy(List<double> bbox) {
  final w = bbox[2], h = bbox[3];
  final dx = w / 2;
  final dy = h / 2;
  final cx = bbox[0];
  final cy = bbox[1];
  return [cx - dx, cy - dy, cx + dx, cy + dy];
}
/// Computes the intersection over union (IoU) of two boxes in `xyxy` format.
///
/// Each box is `[x1, y1, x2, y2]` with `x1 <= x2` and `y1 <= y2`; debug
/// builds check this with assertions. Degenerate boxes with zero width or
/// height are accepted. If the union of the two boxes has no area, the IoU
/// is defined as 0 rather than dividing by zero.
///
/// Returns a value in `[0, 1]`. Boxes that do not overlap, or that only
/// touch along an edge, yield 0.
double computeIou(List<double> bbox1, List<double> bbox2) {
  assert(bbox1[0] <= bbox1[2] && bbox1[1] <= bbox1[3],
      'bbox1 must be [x1, y1, x2, y2] with x1 <= x2 and y1 <= y2: $bbox1');
  assert(bbox2[0] <= bbox2[2] && bbox2[1] <= bbox2[3],
      'bbox2 must be [x1, y1, x2, y2] with x1 <= x2 and y1 <= y2: $bbox2');

  // Corners of the intersection rectangle.
  final xLeft = max(bbox1[0], bbox2[0]);
  final yTop = max(bbox1[1], bbox2[1]);
  final xRight = min(bbox1[2], bbox2[2]);
  final yBottom = min(bbox1[3], bbox2[3]);
  if (xRight < xLeft || yBottom < yTop) {
    return 0;
  }

  final intersectionArea = (xRight - xLeft) * (yBottom - yTop);
  final bbox1Area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]);
  final bbox2Area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]);
  final unionArea = bbox1Area + bbox2Area - intersectionArea;
  if (unionArea <= 0) {
    // Both boxes have zero area; 0 / 0 would otherwise produce NaN.
    return 0;
  }
  final iou = intersectionArea / unionArea;
  assert(iou >= 0 && iou <= 1);
  return iou;
}