Utils.h 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #ifndef __TRT_UTILS_H_
  2. #define __TRT_UTILS_H_
  #include <algorithm>
  #include <cstdio>
  #include <cstdlib>
  #include <cstring>
  #include <iostream>
  #include <string>
  #include <vector>
  #include <cudnn.h>
  // NOTE(review): nvinfer1::IProfiler / nvinfer1::ILogger are used below but the
  // TensorRT header declaring them is not included here; presumably the includer
  // provides it first — confirm.
  #ifndef CUDA_CHECK
  // Evaluates a CUDA runtime call exactly once; if it does not return
  // cudaSuccess, prints the numeric code, its description, and the call site to
  // std::cerr, then aborts.
  //
  // Why the shape:
  //  - do { ... } while (0) makes `CUDA_CHECK(x);` a single statement, so it is
  //    safe as the body of an unbraced if/else.
  //  - the local has an unlikely name so it cannot shadow a caller variable that
  //    is referenced inside `callstr` (e.g. CUDA_CHECK(error_code)).
  //  - std::abort() instead of assert(0): assert compiles away under NDEBUG,
  //    which would let execution continue after a failed CUDA call.
  #define CUDA_CHECK(callstr)                                                   \
      do {                                                                      \
          const cudaError_t cuda_check_status_ = (callstr);                     \
          if (cuda_check_status_ != cudaSuccess) {                              \
              std::cerr << "CUDA error " << cuda_check_status_ << " ("          \
                        << cudaGetErrorString(cuda_check_status_) << ") at "    \
                        << __FILE__ << ":" << __LINE__ << std::endl;            \
              std::abort();                                                     \
          }                                                                     \
      } while (0)
  #endif
  17. namespace Tn
  18. {
  19. class Profiler : public nvinfer1::IProfiler
  20. {
  21. public:
  22. void printLayerTimes(int itrationsTimes)
  23. {
  24. float totalTime = 0;
  25. for (size_t i = 0; i < mProfile.size(); i++)
  26. {
  27. printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(), mProfile[i].second / itrationsTimes);
  28. totalTime += mProfile[i].second;
  29. }
  30. printf("Time over all layers: %4.3f\n", totalTime / itrationsTimes);
  31. }
  32. private:
  33. typedef std::pair<std::string, float> Record;
  34. std::vector<Record> mProfile;
  35. virtual void reportLayerTime(const char* layerName, float ms)
  36. {
  37. auto record = std::find_if(mProfile.begin(), mProfile.end(), [&](const Record& r){ return r.first == layerName; });
  38. if (record == mProfile.end())
  39. mProfile.push_back(std::make_pair(layerName, ms));
  40. else
  41. record->second += ms;
  42. }
  43. };
  44. //Logger for TensorRT info/warning/errors
  45. class Logger : public nvinfer1::ILogger
  46. {
  47. public:
  48. Logger(): Logger(Severity::kWARNING) {}
  49. Logger(Severity severity): reportableSeverity(severity) {}
  50. void log(Severity severity, const char* msg) override
  51. {
  52. // suppress messages with severity enum value greater than the reportable
  53. if (severity > reportableSeverity) return;
  54. switch (severity)
  55. {
  56. case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break;
  57. case Severity::kERROR: std::cerr << "ERROR: "; break;
  58. case Severity::kWARNING: std::cerr << "WARNING: "; break;
  59. case Severity::kINFO: std::cerr << "INFO: "; break;
  60. default: std::cerr << "UNKNOWN: "; break;
  61. }
  62. std::cerr << msg << std::endl;
  63. }
  64. Severity reportableSeverity{Severity::kWARNING};
  65. };
  /// Copies the bytes of `val` into `buffer` and advances `buffer` by sizeof(T).
  ///
  /// std::memcpy is used instead of assigning through reinterpret_cast<T*>:
  /// that cast-and-store is undefined behavior when `buffer` is not aligned for
  /// T and violates strict aliasing. T is expected to be trivially copyable
  /// (a raw byte copy is only meaningful for such types). The caller must
  /// guarantee at least sizeof(T) writable bytes at `buffer`.
  template<typename T>
  void write(char*& buffer, const T& val)
  {
      std::memcpy(buffer, &val, sizeof(T));
      buffer += sizeof(T);
  }
  /// Copies sizeof(T) bytes from `buffer` into `val` and advances `buffer` by
  /// sizeof(T).
  ///
  /// std::memcpy is used instead of dereferencing reinterpret_cast<const T*>:
  /// that load is undefined behavior when `buffer` is not aligned for T and
  /// violates strict aliasing. T is expected to be trivially copyable. The
  /// caller must guarantee at least sizeof(T) readable bytes at `buffer`.
  template<typename T>
  void read(const char*& buffer, T& val)
  {
      std::memcpy(&val, buffer, sizeof(T));
      buffer += sizeof(T);
  }
  78. }
  79. #endif