import 'dart:math';

/// A detection that survived the confidence filter in [nms].
typedef _Candidate = ({
  int anchor,
  int cls,
  double score,
  List<double> xywh,
  List<double> xyxy,
});

/// Performs class-aware non-maximum suppression on raw detector output.
///
/// [rawOutput] is read as `rawOutput[row][anchor]`:
/// - Rows 0-3 hold each box in xywh form (centre x, centre y, width, height;
///   see [xywh2xyxy]).
/// - Every later row holds the score of one class, with class id = row - 4.
///
/// The anchor count is `rawOutput[0].length` and the class count is
/// `rawOutput.length - 4`. A 10 x 8400 input therefore means 6 classes over
/// 8400 anchors (presumably a YOLOv8-style head; confirm against the model).
///
/// Algorithm:
/// 1. For each anchor, pick the best-scoring class (ties go to the lowest
///    class id).
/// 2. Discard anchors whose best score is not strictly greater than
///    [confidenceThreshold].
/// 3. Visit the remaining anchors in descending score order (ties by anchor
///    index). Each kept box suppresses every not-yet-kept box of the same
///    class whose IoU with it is strictly greater than [iouThreshold].
///
/// Returns `(classes, boxes, scores)` as parallel lists in descending score
/// order. Boxes are in the original xywh form. Returns three empty lists when
/// [rawOutput] has no class rows.
(List<int>, List<List<double>>, List<double>) nms(
  List<List<double>> rawOutput, {
  double confidenceThreshold = 0.7,
  double iouThreshold = 0.4,
}) {
  final candidates = _confidentCandidates(rawOutput, confidenceThreshold);

  // Sort the candidates themselves rather than mapping sorted scores back
  // with indexOf: that lookup returns the same index for equal scores,
  // duplicating one box and dropping the other. List.sort is not stable, so
  // ties are broken explicitly by anchor index to keep output deterministic.
  candidates.sort((a, b) {
    final byScore = b.score.compareTo(a.score);
    return byScore != 0 ? byScore : a.anchor.compareTo(b.anchor);
  });

  final finalClasses = <int>[];
  final finalBboxes = <List<double>>[];
  final finalScores = <double>[];
  final suppressed = List<bool>.filled(candidates.length, false);

  for (var i = 0; i < candidates.length; i++) {
    if (suppressed[i]) continue;
    final kept = candidates[i];
    finalClasses.add(kept.cls);
    finalBboxes.add(kept.xywh);
    finalScores.add(kept.score);

    // Suppression is per class: boxes of other classes are never removed.
    for (var j = i + 1; j < candidates.length; j++) {
      final other = candidates[j];
      if (suppressed[j] || other.cls != kept.cls) continue;
      if (computeIou(kept.xyxy, other.xyxy) > iouThreshold) {
        suppressed[j] = true;
      }
    }
  }

  return (finalClasses, finalBboxes, finalScores);
}

/// Returns one candidate per anchor of [rawOutput] whose best class score is
/// strictly greater than [threshold], in ascending anchor order.
List<_Candidate> _confidentCandidates(
  List<List<double>> rawOutput,
  double threshold,
) {
  final candidates = <_Candidate>[];
  // Rows 0-3 are box coordinates; without at least one class row there is
  // nothing to score.
  if (rawOutput.length <= 4) return candidates;

  final numAnchors = rawOutput[0].length;
  for (var anchor = 0; anchor < numAnchors; anchor++) {
    // Take the argmax over the class rows to determine the best class/score.
    var bestCls = 0;
    var bestScore = rawOutput[4][anchor];
    for (var row = 5; row < rawOutput.length; row++) {
      final score = rawOutput[row][anchor];
      if (score > bestScore) {
        bestScore = score;
        bestCls = row - 4;
      }
    }

    if (bestScore > threshold) {
      final xywh = [for (var row = 0; row < 4; row++) rawOutput[row][anchor]];
      candidates.add((
        anchor: anchor,
        cls: bestCls,
        score: bestScore,
        xywh: xywh,
        xyxy: xywh2xyxy(xywh),
      ));
    }
  }
  return candidates;
}

/// Converts a box from centre form `[cx, cy, w, h]` to corner form
/// `[cx - w/2, cy - h/2, cx + w/2, cy + h/2]` (min corner, then max corner).
List<double> xywh2xyxy(List<double> bbox) {
  double halfWidth = bbox[2] / 2;
  double halfHeight = bbox[3] / 2;
  return [
    bbox[0] - halfWidth,
    bbox[1] - halfHeight,
    bbox[0] + halfWidth,
    bbox[1] + halfHeight,
  ];
}

/// Computes the intersection over union between two bounding boxes encoded
/// in the xyxy format (`[x1, y1, x2, y2]` with `x1 <= x2` and `y1 <= y2`).
///
/// Degenerate (zero-width or zero-height) boxes are accepted. When the union
/// has zero area the IoU is defined as 0 instead of dividing 0 by 0.
double computeIou(List<double> bbox1, List<double> bbox2) {
  assert(bbox1[0] <= bbox1[2] && bbox1[1] <= bbox1[3]);
  assert(bbox2[0] <= bbox2[2] && bbox2[1] <= bbox2[3]);

  // Determine the coordinates of the intersection rectangle.
  final xLeft = max(bbox1[0], bbox2[0]);
  final yTop = max(bbox1[1], bbox2[1]);
  final xRight = min(bbox1[2], bbox2[2]);
  final yBottom = min(bbox1[3], bbox2[3]);

  if (xRight < xLeft || yBottom < yTop) {
    return 0.0;
  }

  final intersectionArea = (xRight - xLeft) * (yBottom - yTop);
  final bbox1Area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]);
  final bbox2Area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]);
  final unionArea = bbox1Area + bbox2Area - intersectionArea;
  if (unionArea <= 0) {
    return 0.0;
  }

  final iou = intersectionArea / unionArea;
  assert(iou >= 0 && iou <= 1);
  return iou;
}