OpenShot Library | libopenshot 0.5.0
Loading...
Searching...
No Matches
ObjectDetection.cpp
Go to the documentation of this file.
1
10// Copyright (c) 2008-2019 OpenShot Studios, LLC
11//
12// SPDX-License-Identifier: LGPL-3.0-or-later
13
14#include <fstream>
15#include <iostream>
16#include <algorithm>
17
19#include "effects/Tracker.h"
20#include "Exceptions.h"
21#include "Timeline.h"
22#include "objdetectdata.pb.h"
23
24#include <QImage>
25#include <QPainter>
26#include <QRectF>
27#include <QString>
28#include <QStringList>
29using namespace std;
30using namespace openshot;
31
32
33// Default constructor: initializes the display_box_text and display_boxes
// keyframes to a constant 1.0, then fills in the effect metadata.
// NOTE(review): the constructor's signature line (listing line 34) and the
// statement following the last comment (listing line 42) are not visible in
// this listing -- confirm against the real source file.
35 : display_box_text(1.0)
36 , display_boxes(1.0)
37{
38 // Init effect metadata (see init_effect_details())
39 init_effect_details();
40
41 // We haven’t loaded any protobuf yet, so there's nothing to pick.
43}
44
45// Init effect settings: fills the `info` struct with this effect's class name,
// display name, description and audio/video capability flags.
// NOTE(review): listing lines 48-49, 51 and 57 are not visible in this listing,
// so this block may set more fields than the ones shown -- confirm against the
// real source file.
46void ObjectDetection::init_effect_details()
47{
50
52 info.class_name = "ObjectDetection";
53 info.name = "Object Detector";
54 info.description = "Detect objects through the video.";
// This effect only draws on the image; it does not touch audio
55 info.has_audio = false;
56 info.has_video = true;
58}
59
60// This method is required for all derived classes of EffectBase, and returns a
61// modified openshot::Frame object
62std::shared_ptr<Frame> ObjectDetection::GetFrame(std::shared_ptr<Frame> frame, int64_t frame_number) {
63 // Get the frame's QImage
64 std::shared_ptr<QImage> frame_image = frame->GetImage();
65
66 // Check if frame isn't NULL
67 if(!frame_image || frame_image->isNull()) {
68 return frame;
69 }
70
71 QPainter painter(frame_image.get());
72 painter.setRenderHints(QPainter::Antialiasing | QPainter::SmoothPixmapTransform);
73
74 if (detectionsData.find(frame_number) != detectionsData.end()) {
75 DetectionData detections = detectionsData[frame_number];
76 for (int i = 0; i < detections.boxes.size(); i++) {
77 if (detections.confidences.at(i) < confidence_threshold ||
78 (!display_classes.empty() &&
79 std::find(display_classes.begin(), display_classes.end(), classNames[detections.classIds.at(i)]) == display_classes.end())) {
80 continue;
81 }
82
83 int objectId = detections.objectIds.at(i);
84 auto trackedObject_it = trackedObjects.find(objectId);
85
86 if (trackedObject_it != trackedObjects.end()) {
87 std::shared_ptr<TrackedObjectBBox> trackedObject = std::static_pointer_cast<TrackedObjectBBox>(trackedObject_it->second);
88
89 Clip* parentClip = (Clip*) trackedObject->ParentClip();
90 if (parentClip && trackedObject->Contains(frame_number) && trackedObject->visible.GetValue(frame_number) == 1) {
91 BBox trackedBox = trackedObject->GetBox(frame_number);
92 QRectF boxRect((trackedBox.cx - trackedBox.width / 2) * frame_image->width(),
93 (trackedBox.cy - trackedBox.height / 2) * frame_image->height(),
94 trackedBox.width * frame_image->width(),
95 trackedBox.height * frame_image->height());
96
97 // Get properties of tracked object (i.e. colors, stroke width, etc...)
98 std::vector<int> stroke_rgba = trackedObject->stroke.GetColorRGBA(frame_number);
99 std::vector<int> bg_rgba = trackedObject->background.GetColorRGBA(frame_number);
100 int stroke_width = trackedObject->stroke_width.GetValue(frame_number);
101 float stroke_alpha = trackedObject->stroke_alpha.GetValue(frame_number);
102 float bg_alpha = trackedObject->background_alpha.GetValue(frame_number);
103 float bg_corner = trackedObject->background_corner.GetValue(frame_number);
104
105 // Set the pen for the border
106 QPen pen(QColor(stroke_rgba[0], stroke_rgba[1], stroke_rgba[2], 255 * stroke_alpha));
107 pen.setWidth(stroke_width);
108 painter.setPen(pen);
109
110 // Set the brush for the background
111 QBrush brush(QColor(bg_rgba[0], bg_rgba[1], bg_rgba[2], 255 * bg_alpha));
112 painter.setBrush(brush);
113
114 if (display_boxes.GetValue(frame_number) == 1 && trackedObject->draw_box.GetValue(frame_number) == 1) {
115 // Only draw boxes if both properties are set to YES (draw all boxes, and draw box of the selected box)
116 painter.drawRoundedRect(boxRect, bg_corner, bg_corner);
117 }
118
119 if(display_box_text.GetValue(frame_number) == 1) {
120 // Draw text label above bounding box
121 // Get the confidence and classId for the current detection
122 int classId = detections.classIds.at(i);
123
124 // Get the label for the class name and its confidence
125 QString label = QString::number(objectId);
126 if (!classNames.empty()) {
127 label = QString::fromStdString(classNames[classId]) + ":" + label;
128 }
129
130 // Set up the painter, font, and pen
131 QFont font;
132 font.setPixelSize(14);
133 painter.setFont(font);
134
135 // Calculate the size of the text
136 QFontMetrics fontMetrics(font);
137 QSize labelSize = fontMetrics.size(Qt::TextSingleLine, label);
138
139 // Define the top left point of the rectangle
140 double left = boxRect.center().x() - (labelSize.width() / 2.0);
141 double top = std::max(static_cast<int>(boxRect.top()), labelSize.height()) - 4.0;
142
143 // Draw the text
144 painter.drawText(QPointF(left, top), label);
145 }
146 }
147 }
148 }
149 }
150
151 painter.end();
152
153 // The frame's QImage has been modified in place, so we just return the original frame
154 return frame;
155}
156
157// Load protobuf data file
158bool ObjectDetection::LoadObjDetectdData(std::string inputFilePath)
159{
160 // Parse the file
161 pb_objdetect::ObjDetect objMessage;
162 std::fstream input(inputFilePath, std::ios::in | std::ios::binary);
163 if (!objMessage.ParseFromIstream(&input)) {
164 std::cerr << "Failed to parse protobuf message." << std::endl;
165 return false;
166 }
167
168 // Clear out any old state
169 classNames.clear();
170 detectionsData.clear();
171 trackedObjects.clear();
172
173 // Seed colors for each class
174 std::srand(1);
175 for (int i = 0; i < objMessage.classnames_size(); ++i) {
176 classNames.push_back(objMessage.classnames(i));
177 classesColor.push_back(cv::Scalar(
178 std::rand() % 205 + 50,
179 std::rand() % 205 + 50,
180 std::rand() % 205 + 50
181 ));
182 }
183
184 // Walk every frame in the protobuf
185 for (size_t fi = 0; fi < objMessage.frame_size(); ++fi) {
186 const auto &pbFrame = objMessage.frame(fi);
187 size_t frameId = pbFrame.id();
188
189 // Buffers for DetectionData
190 std::vector<int> classIds;
191 std::vector<float> confidences;
192 std::vector<cv::Rect_<float>> boxes;
193 std::vector<int> objectIds;
194
195 // For each bounding box in this frame
196 for (int di = 0; di < pbFrame.bounding_box_size(); ++di) {
197 const auto &b = pbFrame.bounding_box(di);
198 float x = b.x(), y = b.y(), w = b.w(), h = b.h();
199 int classId = b.classid();
200 float confidence= b.confidence();
201 int objectId = b.objectid();
202
203 // Record for DetectionData
204 classIds.push_back(classId);
205 confidences.push_back(confidence);
206 boxes.emplace_back(x, y, w, h);
207 objectIds.push_back(objectId);
208
209 // Either append to an existing TrackedObjectBBox…
210 auto it = trackedObjects.find(objectId);
211 if (it != trackedObjects.end()) {
212 it->second->AddBox(frameId, x + w/2, y + h/2, w, h, 0.0);
213 }
214 else {
215 // …or create a brand-new one
216 TrackedObjectBBox tmpObj(
217 (int)classesColor[classId][0],
218 (int)classesColor[classId][1],
219 (int)classesColor[classId][2],
220 /*alpha=*/0
221 );
222 tmpObj.stroke_alpha = Keyframe(1.0);
223 tmpObj.AddBox(frameId, x + w/2, y + h/2, w, h, 0.0);
224
225 auto ptr = std::make_shared<TrackedObjectBBox>(tmpObj);
226 ptr->ParentClip(this->ParentClip());
227
228 // Prefix with effect UUID for a unique string ID
229 std::string prefix = this->Id();
230 if (!prefix.empty())
231 prefix += "-";
232 ptr->Id(prefix + std::to_string(objectId));
233 trackedObjects.emplace(objectId, ptr);
234 }
235 }
236
237 // Save the DetectionData for this frame
238 detectionsData[frameId] = DetectionData(
239 classIds, confidences, boxes, frameId, objectIds
240 );
241 }
242
243 google::protobuf::ShutdownProtobufLibrary();
244
245 // Finally, pick a default selectedObjectIndex if we have any
246 if (!trackedObjects.empty()) {
247 selectedObjectIndex = trackedObjects.begin()->first;
248 }
249
250 return true;
251}
252
253// Get the indexes and IDs of all visible objects in the given frame
254std::string ObjectDetection::GetVisibleObjects(int64_t frame_number) const{
255
256 // Initialize the JSON objects
257 Json::Value root;
258 root["visible_objects_index"] = Json::Value(Json::arrayValue);
259 root["visible_objects_id"] = Json::Value(Json::arrayValue);
260 root["visible_class_names"] = Json::Value(Json::arrayValue);
261
262 // Check if track data exists for the requested frame
263 if (detectionsData.find(frame_number) == detectionsData.end()){
264 return root.toStyledString();
265 }
266 DetectionData detections = detectionsData.at(frame_number);
267
268 // Iterate through the tracked objects
269 for(int i = 0; i<detections.boxes.size(); i++){
270 // Does not show boxes with confidence below the threshold
271 if(detections.confidences.at(i) < confidence_threshold){
272 continue;
273 }
274
275 // Get class name of tracked object
276 auto className = classNames[detections.classIds.at(i)];
277
278 // If display_classes is not empty, check if className is in it
279 if (!display_classes.empty()) {
280 auto it = std::find(display_classes.begin(), display_classes.end(), className);
281 if (it == display_classes.end()) {
282 // If not in display_classes, skip this detection
283 continue;
284 }
285 root["visible_class_names"].append(className);
286 } else {
287 // include all class names
288 root["visible_class_names"].append(className);
289 }
290
291 int objectId = detections.objectIds.at(i);
292 // Search for the object in the trackedObjects map
293 auto trackedObject = trackedObjects.find(objectId);
294
295 // Get the tracked object JSON properties for this frame
296 Json::Value trackedObjectJSON = trackedObject->second->PropertiesJSON(frame_number);
297
298 if (trackedObjectJSON["visible"]["value"].asBool() &&
299 trackedObject->second->ExactlyContains(frame_number)){
300 // Save the object's index and ID if it's visible in this frame
301 root["visible_objects_index"].append(trackedObject->first);
302 root["visible_objects_id"].append(trackedObject->second->Id());
303 }
304 }
305
306 return root.toStyledString();
307}
308
309// Generate JSON string of this object
310std::string ObjectDetection::Json() const {
311
312 // Return formatted string
313 return JsonValue().toStyledString();
314}
315
316// Generate Json::Value for this object
317Json::Value ObjectDetection::JsonValue() const {
318
319 // Create root json object
320 Json::Value root = EffectBase::JsonValue(); // get parent properties
321 root["type"] = info.class_name;
322 root["protobuf_data_path"] = protobuf_data_path;
323 root["selected_object_index"] = selectedObjectIndex;
324 root["confidence_threshold"] = confidence_threshold;
325 root["display_box_text"] = display_box_text.JsonValue();
326 root["display_boxes"] = display_boxes.JsonValue();
327
328 // Add tracked object's IDs to root
329 Json::Value objects;
330 for (auto const& trackedObject : trackedObjects){
331 Json::Value trackedObjectJSON = trackedObject.second->JsonValue();
332 // add object json
333 objects[trackedObject.second->Id()] = trackedObjectJSON;
334 }
335 root["objects"] = objects;
336
337 // return JsonValue
338 return root;
339}
340
341// Load JSON string into this object
342void ObjectDetection::SetJson(const std::string value) {
343
344 // Parse JSON string into JSON objects
345 try
346 {
347 const Json::Value root = openshot::stringToJson(value);
348 // Set all values that match
349 SetJsonValue(root);
350 }
351 catch (const std::exception& e)
352 {
353 // Error parsing JSON (or missing keys)
354 throw InvalidJSON("JSON is invalid (missing keys or invalid data types)");
355 }
356}
357
358// Load Json::Value into this object.
// Applies, in order: the protobuf data path (reloading the detection data when
// needed), the scalar/keyframe settings, the class filter, per-object overrides
// ("objects") and the legacy per-object IDs ("objects_id"). Keys that are null
// or absent leave the corresponding member untouched.
// Throws InvalidFile when the protobuf file cannot be loaded.
359void ObjectDetection::SetJsonValue(const Json::Value root)
360{
361 // Parent properties
 // NOTE(review): the statement belonging to this comment (listing line 362)
 // is not visible in this listing -- confirm against the real source file.
363
364 // If a protobuf path is provided, load & prefix IDs
365 if (!root["protobuf_data_path"].isNull()) {
366 std::string new_path = root["protobuf_data_path"].asString();
 // Reload only when the path changed or nothing has been loaded yet
367 if (protobuf_data_path != new_path || trackedObjects.empty()) {
368 protobuf_data_path = new_path;
369 if (!LoadObjDetectdData(protobuf_data_path)) {
370 throw InvalidFile("Invalid protobuf data path", "");
371 }
372 }
373 }
374
375 // Selected index, thresholds, UI flags, filters, etc.
376 if (!root["selected_object_index"].isNull())
377 selectedObjectIndex = root["selected_object_index"].asInt();
378 if (!root["confidence_threshold"].isNull())
379 confidence_threshold = root["confidence_threshold"].asFloat();
380 if (!root["display_box_text"].isNull())
381 display_box_text.SetJsonValue(root["display_box_text"]);
382 if (!root["display_boxes"].isNull())
383 display_boxes.SetJsonValue(root["display_boxes"]);
384
 // class_filter is a comma-separated list; each entry is trimmed and
 // lower-cased, and empty entries are dropped.
 // NOTE(review): GetFrame/GetVisibleObjects compare these entries against
 // classNames exactly as loaded, so a class name containing upper-case
 // letters would never match -- confirm the intended case handling.
385 if (!root["class_filter"].isNull()) {
386 class_filter = root["class_filter"].asString();
387 QStringList parts = QString::fromStdString(class_filter).split(',');
388 display_classes.clear();
389 for (auto &p : parts) {
390 auto s = p.trimmed().toLower();
391 if (!s.isEmpty()) {
392 display_classes.push_back(s.toStdString());
393 }
394 }
395 }
396
397 // Apply any per-object overrides
398 if (!root["objects"].isNull()) {
399 // Iterate over the supplied objects (indexed by id or position)
400 const auto memberNames = root["objects"].getMemberNames();
401 for (const auto& name : memberNames)
402 {
403 // Determine the numeric index of this object: an all-digit key is the
 // index itself, otherwise the index is the number after the last '-'
 // (the "<effect id>-<object id>" form assigned in LoadObjDetectdData)
404 int index = -1;
 // NOTE(review): std::all_of is true for an empty key, and the std::stoi
 // call in the numeric branch is not wrapped in try/catch, so an empty
 // or out-of-int-range digit key throws from here.
405 bool numeric_key = std::all_of(name.begin(), name.end(), ::isdigit);
406 if (numeric_key) {
407 index = std::stoi(name);
408 }
409 else
410 {
411 size_t pos = name.find_last_of('-');
412 if (pos != std::string::npos) {
413 try {
414 index = std::stoi(name.substr(pos + 1));
415 } catch (...) {
 // Suffix is not a number: leave the index unresolved
416 index = -1;
417 }
418 }
419 }
420
 // Keys that do not resolve to a known tracked object are ignored
421 auto obj_it = trackedObjects.find(index);
422 if (obj_it != trackedObjects.end() && obj_it->second) {
423 // Update object id if provided as a non-numeric key
424 if (!numeric_key)
425 obj_it->second->Id(name);
426 obj_it->second->SetJsonValue(root["objects"][name]);
427 }
428 }
429 }
430 // Set the tracked object's ids (legacy format)
 // "objects_id" is indexed by the integer map key of each tracked object
431 if (!root["objects_id"].isNull()) {
432 for (auto& kv : trackedObjects) {
433 if (!root["objects_id"][kv.first].isNull())
434 kv.second->Id(root["objects_id"][kv.first].asString());
435 }
436 }
437}
438
439// Get all properties for a specific frame
440std::string ObjectDetection::PropertiesJSON(int64_t requested_frame) const {
441
442 // Generate JSON properties list
443 Json::Value root = BasePropertiesJSON(requested_frame);
444
445 Json::Value objects;
446 if(trackedObjects.count(selectedObjectIndex) != 0){
447 auto selectedObject = trackedObjects.at(selectedObjectIndex);
448 if (selectedObject){
449 Json::Value trackedObjectJSON = selectedObject->PropertiesJSON(requested_frame);
450 // add object json
451 objects[selectedObject->Id()] = trackedObjectJSON;
452 }
453 }
454 root["objects"] = objects;
455
456 root["selected_object_index"] = add_property_json("Selected Object", selectedObjectIndex, "int", "", NULL, 0, 200, false, requested_frame);
457 root["confidence_threshold"] = add_property_json("Confidence Theshold", confidence_threshold, "float", "", NULL, 0, 1, false, requested_frame);
458 root["class_filter"] = add_property_json("Class Filter", 0.0, "string", class_filter, NULL, -1, -1, false, requested_frame);
459
460 root["display_box_text"] = add_property_json("Draw All Text", display_box_text.GetValue(requested_frame), "int", "", &display_box_text, 0, 1, false, requested_frame);
461 root["display_box_text"]["choices"].append(add_property_choice_json("Yes", true, display_box_text.GetValue(requested_frame)));
462 root["display_box_text"]["choices"].append(add_property_choice_json("No", false, display_box_text.GetValue(requested_frame)));
463
464 root["display_boxes"] = add_property_json("Draw All Boxes", display_boxes.GetValue(requested_frame), "int", "", &display_boxes, 0, 1, false, requested_frame);
465 root["display_boxes"]["choices"].append(add_property_choice_json("Yes", true, display_boxes.GetValue(requested_frame)));
466 root["display_boxes"]["choices"].append(add_property_choice_json("No", false, display_boxes.GetValue(requested_frame)));
467
468 // Return formatted string
469 return root.toStyledString();
470}
Header file for all Exception classes.
Header file for Object Detection effect class.
Header file for Timeline class.
Header file for Tracker effect class.
std::string Id() const
Get the Id of this clip object.
Definition ClipBase.h:85
Json::Value add_property_choice_json(std::string name, int value, int selected_value) const
Generate JSON choice for a property (dropdown properties)
Definition ClipBase.cpp:132
Json::Value add_property_json(std::string name, float value, std::string type, std::string memo, const Keyframe *keyframe, float min_value, float max_value, bool readonly, int64_t requested_frame) const
Generate JSON for a property.
Definition ClipBase.cpp:96
This class represents a clip (used to arrange readers on the timeline)
Definition Clip.h:89
virtual Json::Value JsonValue() const
Generate Json::Value for this object.
openshot::ClipBase * ParentClip()
Parent clip object of this effect (which can be unparented and NULL)
Json::Value BasePropertiesJSON(int64_t requested_frame) const
Generate JSON object of base properties (recommended to be used by all effects)
virtual void SetJsonValue(const Json::Value root)
Load Json::Value into this object.
EffectInfoStruct info
Information about the current effect.
Definition EffectBase.h:69
std::map< int, std::shared_ptr< openshot::TrackedObjectBase > > trackedObjects
Map of Tracked Objects by their indices (used by Effects that track objects on clips)
Definition EffectBase.h:66
Exception for files that can not be found or opened.
Definition Exceptions.h:188
Exception for invalid JSON.
Definition Exceptions.h:218
A Keyframe is a collection of Point instances, which is used to vary a number or property over time.
Definition KeyFrame.h:53
void SetJsonValue(const Json::Value root)
Load Json::Value into this object.
Definition KeyFrame.cpp:372
double GetValue(int64_t index) const
Get the value at a specific index.
Definition KeyFrame.cpp:258
Json::Value JsonValue() const
Generate Json::Value for this object.
Definition KeyFrame.cpp:339
Json::Value JsonValue() const override
Generate Json::Value for this object.
int selectedObjectIndex
Index of the Tracked Object that was selected to modify its properties.
std::shared_ptr< Frame > GetFrame(std::shared_ptr< Frame > frame, int64_t frame_number) override
This method is required for all derived classes of EffectBase, and returns a modified openshot::Frame...
ObjectDetection()
Default constructor.
bool LoadObjDetectdData(std::string inputFilePath)
Load protobuf data file.
std::string GetVisibleObjects(int64_t frame_number) const override
Get the indexes and IDs of all visible objects in the given frame.
std::string Json() const override
Generate JSON string of this object.
std::string PropertiesJSON(int64_t requested_frame) const override
void SetJsonValue(const Json::Value root) override
Load Json::Value into this object.
void SetJson(const std::string value) override
Load JSON string into this object.
openshot::ClipBase * ParentClip()
Parent clip object of this reader (which can be unparented and NULL)
This class contains the properties of a tracked object and functions to manipulate it.
void AddBox(int64_t _frame_num, float _cx, float _cy, float _width, float _height, float _angle) override
Add a BBox to the BoxVec map.
Keyframe stroke_alpha
Stroke box opacity.
This namespace is the default namespace for all code in the openshot library.
Definition Compressor.h:29
const Json::Value stringToJson(const std::string value)
Definition Json.cpp:16
std::vector< cv::Rect_< float > > boxes
std::vector< float > confidences
std::vector< int > classIds
std::vector< int > objectIds
This struct holds the information of a bounding-box.
float cy
y-coordinate of the bounding box center
float height
bounding box height
float cx
x-coordinate of the bounding box center
float width
bounding box width
bool has_video
Determines if this effect manipulates the image of a frame.
Definition EffectBase.h:40
bool has_audio
Determines if this effect manipulates the audio of a frame.
Definition EffectBase.h:41
std::string class_name
The class name of the effect.
Definition EffectBase.h:36
std::string name
The name of the effect.
Definition EffectBase.h:37
std::string description
The description of this effect and what it does.
Definition EffectBase.h:38
bool has_tracked_object
Determines if this effect tracks objects through the clip.
Definition EffectBase.h:42