// laneatt.hh

#pragma once

#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <opencv2/core.hpp>
  6. class Logger : public nvinfer1::ILogger {
  7. public:
  8. explicit Logger(Severity severity = Severity::kWARNING)
  9. : reportable_severity(severity) {}
  10. void log(Severity severity, const char* msg) noexcept {
  11. if (severity > reportable_severity) {
  12. return;
  13. }
  14. switch (severity) {
  15. case Severity::kINTERNAL_ERROR:
  16. std::cerr << "INTERNAL_ERROR: ";
  17. break;
  18. case Severity::kERROR:
  19. std::cerr << "ERROR: ";
  20. break;
  21. case Severity::kWARNING:
  22. std::cerr << "WARNING: ";
  23. break;
  24. case Severity::kINFO:
  25. std::cerr << "INFO: ";
  26. break;
  27. default:
  28. std::cerr << "UNKNOWN: ";
  29. break;
  30. }
  31. std::cerr << msg << std::endl;
  32. }
  33. Severity reportable_severity;
  34. };
  35. struct Detection {
  36. float unknown;
  37. float score;
  38. float start_y;
  39. float start_x;
  40. float length;
  41. float lane_xs[72];
  42. };
  43. class LaneATT {
  44. public:
  45. LaneATT(const std::string& plan_path);
  46. ~LaneATT();
  47. std::vector<std::vector<cv::Point2f>> DetectLane(const cv::Mat& raw_image);
  48. private:
  49. void LoadEngine(const std::string& engine_file);
  50. std::vector<std::vector<cv::Point2f>> PostProcess(cv::Mat& lane_image, cv::Mat raw_image, float conf_thresh=0.5f, float nms_thresh=50.f, int nms_topk=4);
  51. Logger g_logger_;
  52. cudaStream_t stream_;
  53. nvinfer1::ICudaEngine* engine_;
  54. nvinfer1::IExecutionContext* context_;
  55. void* buffers_[2];
  56. int buffer_size_[2];
  57. std::vector<float> image_data_;
  58. std::vector<Detection> detections_;
  59. };