// detect_obstacle.cpp (9.0 KB, 259 lines)
  1. #include <set>
  2. #include "detect_obstacle.h"
  3. namespace od{
  4. // Computes IOU between two bounding boxes
  5. double GetIOU(Rect_<float> bb_test, Rect_<float> bb_gt)
  6. {
  7. float in = (bb_test & bb_gt).area();
  8. float un = bb_test.area() + bb_gt.area() - in;
  9. if (un < DBL_EPSILON)
  10. return 0;
  11. return (double)(in / un);
  12. }
  13. //tracking obstacle
  14. bool TrackObstacle(int frame_count,vector<KalmanTracker> &trackers,vector<bbox_t> &outs,vector<od::TrackingBox> &track_result)
  15. {
  16. // variables used in the for-loop
  17. vector<Rect_<float>> predictedBoxes;
  18. vector<vector<double>> iouMatrix;
  19. vector<int> assignment;
  20. set<int> unmatchedDetections;
  21. set<int> unmatchedTrajectories;
  22. set<int> allItems;
  23. set<int> matchedItems;
  24. vector<cv::Point> matchedPairs;
  25. unsigned int trkNum = 0;
  26. unsigned int detNum = 0;
  27. vector<od::DetectBox> detect_outs;
  28. //bbox_t to Detect_box
  29. for(unsigned int i=0;i<outs.size();i++)
  30. {
  31. od::DetectBox detect_temp;
  32. detect_temp.class_id = outs[i].obj_id;
  33. detect_temp.prob = outs[i].prob;
  34. float tpx = outs[i].x;
  35. float tpy = outs[i].y;
  36. float tpw = outs[i].w;
  37. float tph = outs[i].h;
  38. //detect_temp.box = Rect_<float>(Point_<float>(tpx, tpy),Point_<float>(tpx + tpw, tpy + tph));
  39. detect_temp.box = Rect_<float>(tpx,tpy,tpw,tph);
  40. detect_outs.push_back(detect_temp);
  41. }
  42. //tracking
  43. if (trackers.size() == 0) // the first frame met
  44. {
  45. // initialize kalman trackers using first detections.
  46. for (unsigned int i = 0; i < outs.size(); i++)
  47. {
  48. KalmanTracker trk = KalmanTracker(detect_outs[i].box,
  49. detect_outs[i].class_id,
  50. detect_outs[i].prob);
  51. trackers.push_back(trk);
  52. }
  53. return false;
  54. }
  55. ///////////////////////////////////////
  56. // 3.1. get predicted locations from existing trackers.
  57. predictedBoxes.clear();
  58. for (auto it = trackers.begin(); it != trackers.end();)
  59. {
  60. Rect_<float> pBox = (*it).predict();
  61. if (pBox.x >= 0 && pBox.y >= 0)
  62. {
  63. predictedBoxes.push_back(pBox);
  64. it++;
  65. }
  66. else
  67. {
  68. cerr << "Box invalid at frame: " << frame_count <<" id "<<(*it).m_id+1<<endl;
  69. it = trackers.erase(it);
  70. }
  71. }
  72. if (trackers.size() == 0 || detect_outs.size() == 0) return false;
  73. ///////////////////////////////////////
  74. // 3.2. associate detections to tracked object (both represented as bounding boxes)
  75. // dets : detFrameData[fi]
  76. trkNum = predictedBoxes.size();
  77. detNum = outs.size();
  78. iouMatrix.clear();
  79. iouMatrix.resize(trkNum, vector<double>(detNum, 0));
  80. for (unsigned int i = 0; i < trkNum; i++) // compute iou matrix as a distance matrix
  81. {
  82. for (unsigned int j = 0; j < detNum; j++)
  83. {
  84. // use 1-iou because the hungarian algorithm computes a minimum-cost assignment.
  85. iouMatrix[i][j] = 1 - GetIOU(predictedBoxes[i], detect_outs[j].box);
  86. }
  87. }
  88. // solve the assignment problem using hungarian algorithm.
  89. // the resulting assignment is [track(prediction) : detection], with len=preNum
  90. HungarianAlgorithm HungAlgo;
  91. assignment.clear();
  92. HungAlgo.Solve(iouMatrix, assignment);
  93. // find matches, unmatched_detections and unmatched_predictions
  94. unmatchedTrajectories.clear();
  95. unmatchedDetections.clear();
  96. allItems.clear();
  97. matchedItems.clear();
  98. if (detNum > trkNum) // there are unmatched detections
  99. {
  100. for (unsigned int n = 0; n < detNum; n++)
  101. allItems.insert(n);
  102. for (unsigned int i = 0; i < trkNum; ++i)
  103. matchedItems.insert(assignment[i]);
  104. set_difference(allItems.begin(), allItems.end(),
  105. matchedItems.begin(), matchedItems.end(),
  106. insert_iterator<set<int>>(unmatchedDetections, unmatchedDetections.begin()));
  107. }
  108. else if (detNum < trkNum) // there are unmatched trajectory/predictions
  109. {
  110. for (unsigned int i = 0; i < trkNum; ++i)
  111. if (assignment[i] == -1) // unassigned label will be set as -1 in the assignment algorithm
  112. unmatchedTrajectories.insert(i);
  113. }
  114. // filter out matched with low IOU
  115. matchedPairs.clear();
  116. for (unsigned int i = 0; i < trkNum; ++i)
  117. {
  118. if (assignment[i] == -1) // pass over invalid values
  119. continue;
  120. if (1 - iouMatrix[i][assignment[i]] < od::iouThreshold)
  121. {
  122. unmatchedTrajectories.insert(i);
  123. unmatchedDetections.insert(assignment[i]);
  124. }
  125. else
  126. matchedPairs.push_back(cv::Point(i, assignment[i]));
  127. }
  128. ///////////////////////////////////////
  129. // 3.3. updating trackers
  130. // update matched trackers with assigned detections.
  131. // each prediction is corresponding to a tracker
  132. int detIdx, trkIdx;
  133. for (unsigned int i = 0; i < matchedPairs.size(); i++)
  134. {
  135. trkIdx = matchedPairs[i].x;
  136. detIdx = matchedPairs[i].y;
  137. trackers[trkIdx].update(detect_outs[detIdx].box,
  138. detect_outs[detIdx].class_id,
  139. detect_outs[detIdx].prob);
  140. }
  141. // create and initialise new trackers for unmatched detections
  142. for (auto umd : unmatchedDetections)
  143. {
  144. KalmanTracker tracker = KalmanTracker(detect_outs[umd].box,
  145. detect_outs[umd].class_id,
  146. detect_outs[umd].prob);
  147. trackers.push_back(tracker);
  148. }
  149. #if 0
  150. //get unique trackers,merg same trackers
  151. unsigned int trackers_num = trackers.size();
  152. iouMatrix.clear();
  153. iouMatrix.resize(trackers_num, vector<double>(trackers_num, 0));
  154. for (unsigned int i = 0; i < trackers_num; i++) // compute iou matrix as a distance matrix
  155. {
  156. for (unsigned int j = 0; j < trackers_num; j++)
  157. {
  158. // use 1-iou because the hungarian algorithm computes a minimum-cost assignment.
  159. if(j==i)
  160. iouMatrix[i][j] = 1;
  161. else
  162. iouMatrix[i][j] = 1 - GetIOU(trackers[i].get_state(), trackers[j].get_state());
  163. }
  164. }
  165. // solve the assignment problem using hungarian algorithm.
  166. // the resulting assignment is [track(prediction) : detection], with len=preNum
  167. assignment.clear();
  168. HungAlgo.Solve(iouMatrix, assignment);
  169. // filter out matched with low IOU
  170. matchedPairs.clear();
  171. for (unsigned int i = 0; i < trackers_num; ++i)
  172. {
  173. if (assignment[i] == -1) // pass over invalid values
  174. continue;
  175. if (iouMatrix[i][assignment[i]] < od::iouThreshold)
  176. {
  177. matchedPairs.push_back(cv::Point(i, assignment[i]));
  178. }
  179. }
  180. int index1,index2;
  181. vector<int> delete_index;
  182. for (unsigned int i = 0; i < matchedPairs.size(); i++)
  183. {
  184. index1 = matchedPairs[i].x;
  185. index2 = matchedPairs[i].y;
  186. if(index1 >= index2)
  187. continue;
  188. if((trackers[index1].m_id > trackers[index2].m_id) && (trackers[index1].m_class_history.size()>0))
  189. {
  190. trackers[index1].m_id = trackers[index2].m_id;
  191. trackers[index1].m_class_history.insert(trackers[index1].m_class_history.begin(),
  192. trackers[index2].m_class_history.begin(),trackers[index2].m_class_history.end());
  193. delete_index.push_back(index2);
  194. }
  195. else if((trackers[index2].m_id > trackers[index1].m_id) && (trackers[index2].m_class_history.size()>0))
  196. {
  197. trackers[index2].m_id = trackers[index1].m_id;
  198. trackers[index2].m_class_history.insert(trackers[index2].m_class_history.begin(),
  199. trackers[index1].m_class_history.begin(),trackers[index1].m_class_history.end());
  200. delete_index.push_back(index1);
  201. }
  202. }
  203. for(unsigned int i = 0; i < delete_index.size(); i++)
  204. {
  205. int idx = delete_index[i] - i;
  206. trackers.erase(trackers.begin() + idx);
  207. }
  208. #endif
  209. // get trackers' output
  210. track_result.clear();
  211. for (auto it = trackers.begin(); it != trackers.end();)
  212. {
  213. if (((*it).m_time_since_update <= od::max_age) &&
  214. ((*it).m_hit_streak >= od::min_hits || frame_count <= od::min_hits))
  215. {
  216. od::TrackingBox res;
  217. res.box = (*it).get_state();
  218. res.id = (*it).m_id + 1;
  219. res.frame = frame_count;
  220. res.class_id = (*it).m_class_id;
  221. res.prob = (*it).m_prob;
  222. res.class_history = (*it).m_class_history;
  223. track_result.push_back(res);
  224. it++;
  225. }
  226. else
  227. it ++;
  228. //remove dead tracklet
  229. if(it != trackers.end() && (*it).m_time_since_update > od::max_age)
  230. {
  231. it = trackers.erase(it);
  232. }
  233. }
  234. if(track_result.size()>0)
  235. return true;
  236. else return false;
  237. }