OpenShot Library | libopenshot 0.5.0
Loading...
Searching...
No Matches
sort.cpp
Go to the documentation of this file.
1// © OpenShot Studios, LLC
2//
3// SPDX-License-Identifier: LGPL-3.0-or-later
4
5#include "sort.hpp"
6
7using namespace std;
8
9// Constructor
10SortTracker::SortTracker(int max_age, int min_hits, int max_missed, double min_iou, double nms_iou_thresh, double min_conf)
11{
12 _min_hits = min_hits;
13 _max_age = max_age;
14 _max_missed = max_missed;
15 _min_iou = min_iou;
16 _nms_iou_thresh = nms_iou_thresh;
17 _min_conf = min_conf;
18 _next_id = 0;
19 alive_tracker = true;
20}
21
22// Computes IOU between two bounding boxes
23double SortTracker::GetIOU(cv::Rect_<float> bb_test, cv::Rect_<float> bb_gt)
24{
25 float in = (bb_test & bb_gt).area();
26 float un = bb_test.area() + bb_gt.area() - in;
27
28 if (un < DBL_EPSILON)
29 return 0;
30
31 return (double)(in / un);
32}
33
34// Computes centroid distance between two bounding boxes
36 cv::Rect_<float> bb_test,
37 cv::Rect_<float> bb_gt)
38{
39 float bb_test_centroid_x = (bb_test.x + bb_test.width / 2);
40 float bb_test_centroid_y = (bb_test.y + bb_test.height / 2);
41
42 float bb_gt_centroid_x = (bb_gt.x + bb_gt.width / 2);
43 float bb_gt_centroid_y = (bb_gt.y + bb_gt.height / 2);
44
45 double distance = (double)sqrt(pow(bb_gt_centroid_x - bb_test_centroid_x, 2) + pow(bb_gt_centroid_y - bb_test_centroid_y, 2));
46
47 return distance;
48}
49
50// Function to apply NMS on detections
51void apply_nms(vector<TrackingBox>& detections, double nms_iou_thresh) {
52 if (detections.empty()) return;
53
54 // Sort detections by confidence descending
55 std::sort(detections.begin(), detections.end(), [](const TrackingBox& a, const TrackingBox& b) {
56 return a.confidence > b.confidence;
57 });
58
59 vector<bool> suppressed(detections.size(), false);
60
61 for (size_t i = 0; i < detections.size(); ++i) {
62 if (suppressed[i]) continue;
63
64 for (size_t j = i + 1; j < detections.size(); ++j) {
65 if (suppressed[j]) continue;
66
67 if (detections[i].classId == detections[j].classId &&
68 SortTracker::GetIOU(detections[i].box, detections[j].box) > nms_iou_thresh) {
69 suppressed[j] = true;
70 }
71 }
72 }
73
74 // Remove suppressed detections
75 vector<TrackingBox> filtered;
76 for (size_t i = 0; i < detections.size(); ++i) {
77 if (!suppressed[i]) {
78 filtered.push_back(detections[i]);
79 }
80 }
81 detections = filtered;
82}
83
84void SortTracker::update(vector<cv::Rect> detections_cv, int frame_count, double image_diagonal, std::vector<float> confidences, std::vector<int> classIds)
85{
86 vector<TrackingBox> detections;
87 if (trackers.size() == 0) // the first frame met
88 {
89 alive_tracker = false;
90 // initialize kalman trackers using first detections.
91 for (unsigned int i = 0; i < detections_cv.size(); i++)
92 {
93 if (confidences[i] < _min_conf) continue; // filter low conf
94
95 TrackingBox tb;
96
97 tb.box = cv::Rect_<float>(detections_cv[i]);
98 tb.classId = classIds[i];
99 tb.confidence = confidences[i];
100 detections.push_back(tb);
101
102 KalmanTracker trk = KalmanTracker(detections.back().box, detections.back().confidence, detections.back().classId, _next_id++);
103 trackers.push_back(trk);
104 }
105 return;
106 }
107 else
108 {
109 for (unsigned int i = 0; i < detections_cv.size(); i++)
110 {
111 if (confidences[i] < _min_conf) continue; // filter low conf
112
113 TrackingBox tb;
114 tb.box = cv::Rect_<float>(detections_cv[i]);
115 tb.classId = classIds[i];
116 tb.confidence = confidences[i];
117 detections.push_back(tb);
118 }
119
120 // Apply NMS to remove duplicates
121 apply_nms(detections, _nms_iou_thresh);
122
123 for (auto it = frameTrackingResult.begin(); it != frameTrackingResult.end(); it++)
124 {
125 int frame_age = frame_count - it->frame;
126 if (frame_age >= _max_age || frame_age < 0)
127 {
128 dead_trackers_id.push_back(it->id);
129 }
130 }
131 }
132
134 // 3.1. get predicted locations from existing trackers.
135 predictedBoxes.clear();
136 for (unsigned int i = 0; i < trackers.size();)
137 {
138 cv::Rect_<float> pBox = trackers[i].predict();
139 if (pBox.x >= 0 && pBox.y >= 0)
140 {
141 predictedBoxes.push_back(pBox);
142 i++;
143 continue;
144 }
145 trackers.erase(trackers.begin() + i);
146 }
147
148 trkNum = predictedBoxes.size();
149 detNum = detections.size();
150
151 cost_matrix.clear();
152 cost_matrix.resize(trkNum, vector<double>(detNum, 0));
153
154 for (unsigned int i = 0; i < trkNum; i++) // compute cost matrix using 1 - IOU with gating
155 {
156 for (unsigned int j = 0; j < detNum; j++)
157 {
158 double iou = GetIOU(predictedBoxes[i], detections[j].box);
159 double dist = GetCentroidsDistance(predictedBoxes[i], detections[j].box) / image_diagonal;
160 if (trackers[i].classId != detections[j].classId || dist > max_centroid_dist_norm)
161 {
162 cost_matrix[i][j] = 1e9; // large cost for gating
163 }
164 else
165 {
166 cost_matrix[i][j] = 1 - iou + (1 - detections[j].confidence) * 0.1; // slight penalty for low conf
167 }
168 }
169 }
170
171 HungarianAlgorithm HungAlgo;
172 assignment.clear();
173 HungAlgo.Solve(cost_matrix, assignment);
174 // find matches, unmatched_detections and unmatched_predictions
175 unmatchedTrajectories.clear();
176 unmatchedDetections.clear();
177 allItems.clear();
178 matchedItems.clear();
179
180 if (detNum > trkNum) // there are unmatched detections
181 {
182 for (unsigned int n = 0; n < detNum; n++)
183 allItems.insert(n);
184
185 for (unsigned int i = 0; i < trkNum; ++i)
186 matchedItems.insert(assignment[i]);
187
188 set_difference(allItems.begin(), allItems.end(),
189 matchedItems.begin(), matchedItems.end(),
190 insert_iterator<set<int>>(unmatchedDetections, unmatchedDetections.begin()));
191 }
192 else if (detNum < trkNum) // there are unmatched trajectory/predictions
193 {
194 for (unsigned int i = 0; i < trkNum; ++i)
195 if (assignment[i] == -1) // unassigned label will be set as -1 in the assignment algorithm
196 unmatchedTrajectories.insert(i);
197 }
198 else
199 ;
200
201 // filter out matched with low IOU
202 matchedPairs.clear();
203 for (unsigned int i = 0; i < trkNum; ++i)
204 {
205 if (assignment[i] == -1) // pass over invalid values
206 continue;
207 if (cost_matrix[i][assignment[i]] > 1 - _min_iou)
208 {
209 unmatchedTrajectories.insert(i);
211 }
212 else
213 matchedPairs.push_back(cv::Point(i, assignment[i]));
214 }
215
216 for (unsigned int i = 0; i < matchedPairs.size(); i++)
217 {
218 int trkIdx = matchedPairs[i].x;
219 int detIdx = matchedPairs[i].y;
220 trackers[trkIdx].update(detections[detIdx].box);
221 trackers[trkIdx].classId = detections[detIdx].classId;
222 trackers[trkIdx].confidence = detections[detIdx].confidence;
223 }
224
225 // create and initialise new trackers for unmatched detections
226 for (auto umd : unmatchedDetections)
227 {
228 KalmanTracker tracker = KalmanTracker(detections[umd].box, detections[umd].confidence, detections[umd].classId, _next_id++);
229 trackers.push_back(tracker);
230 }
231
232 for (auto it2 = dead_trackers_id.begin(); it2 != dead_trackers_id.end(); it2++)
233 {
234 for (unsigned int i = 0; i < trackers.size();)
235 {
236 if (trackers[i].m_id == (*it2))
237 {
238 trackers.erase(trackers.begin() + i);
239 continue;
240 }
241 i++;
242 }
243 }
244
245 // get trackers' output
246 frameTrackingResult.clear();
247 for (unsigned int i = 0; i < trackers.size();)
248 {
249 if ((trackers[i].m_hits >= _min_hits && trackers[i].m_time_since_update <= _max_missed) ||
250 frame_count <= _min_hits)
251 {
252 alive_tracker = true;
253 TrackingBox res;
254 res.box = trackers[i].get_state();
255 res.id = trackers[i].m_id;
256 res.frame = frame_count;
257 res.classId = trackers[i].classId;
258 res.confidence = trackers[i].confidence;
259 frameTrackingResult.push_back(res);
260 }
261
262 // remove dead tracklet
263 if (trackers[i].m_time_since_update >= _max_age)
264 {
265 trackers.erase(trackers.begin() + i);
266 continue;
267 }
268 i++;
269 }
270}
double Solve(std::vector< std::vector< double > > &DistMatrix, std::vector< int > &Assignment)
Definition Hungarian.cpp:26
This class represents the internal state of an individual tracked object observed as a bounding box.
unsigned int trkNum
Definition sort.hpp:60
double max_centroid_dist_norm
Definition sort.hpp:46
std::vector< int > dead_trackers_id
Definition sort.hpp:58
double _min_iou
Definition sort.hpp:65
std::vector< TrackingBox > frameTrackingResult
Definition sort.hpp:57
std::set< int > unmatchedDetections
Definition sort.hpp:51
std::vector< std::vector< double > > cost_matrix
Definition sort.hpp:49
int _max_age
Definition sort.hpp:63
double _nms_iou_thresh
Definition sort.hpp:66
double _min_conf
Definition sort.hpp:67
std::vector< cv::Point > matchedPairs
Definition sort.hpp:55
SortTracker(int max_age=50, int min_hits=5, int max_missed=7, double min_iou=0.1, double nms_iou_thresh=0.5, double min_conf=0.3)
Definition sort.cpp:10
double GetCentroidsDistance(cv::Rect_< float > bb_test, cv::Rect_< float > bb_gt)
Definition sort.cpp:35
std::vector< KalmanTracker > trackers
Definition sort.hpp:44
std::vector< cv::Rect_< float > > predictedBoxes
Definition sort.hpp:48
int _min_hits
Definition sort.hpp:62
void update(std::vector< cv::Rect > detection, int frame_count, double image_diagonal, std::vector< float > confidences, std::vector< int > classIds)
Definition sort.cpp:84
std::set< int > allItems
Definition sort.hpp:53
unsigned int detNum
Definition sort.hpp:61
bool alive_tracker
Definition sort.hpp:69
std::vector< int > assignment
Definition sort.hpp:50
std::set< int > matchedItems
Definition sort.hpp:54
std::set< int > unmatchedTrajectories
Definition sort.hpp:52
int _max_missed
Definition sort.hpp:64
unsigned int _next_id
Definition sort.hpp:68
static double GetIOU(cv::Rect_< float > bb_test, cv::Rect_< float > bb_gt)
Definition sort.cpp:23
void apply_nms(vector< TrackingBox > &detections, double nms_iou_thresh)
Definition sort.cpp:51
cv::Rect_< float > box
Definition sort.hpp:28
int frame
Definition sort.hpp:24
float confidence
Definition sort.hpp:25
int classId
Definition sort.hpp:26