// trt_common.cpp
// Copyright 2020 Tier IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <trt_common.hpp>

// NOTE(review): GCC defines __GNUC__ (major version), not __GNUC_MAJOR__.
// The old spelling was undefined, evaluated to 0 in the #if, and forced the
// deprecated <experimental/filesystem> fallback even on modern GCC.
#if (defined(_MSC_VER) or (defined(__GNUC__) and (7 <= __GNUC__)))
#include <filesystem>
namespace fs = ::std::filesystem;
#else
#include <experimental/filesystem>
namespace fs = ::std::experimental::filesystem;
#endif

#include <fstream>
#include <functional>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
  24. namespace Tn
  25. {
  26. void check_error(const ::cudaError_t e, decltype(__FILE__) f, decltype(__LINE__) n)
  27. {
  28. if (e != ::cudaSuccess) {
  29. std::stringstream s;
  30. s << ::cudaGetErrorName(e) << " (" << e << ")@" << f << "#L" << n << ": "
  31. << ::cudaGetErrorString(e);
  32. throw std::runtime_error{s.str()};
  33. }
  34. }
  35. TrtCommon::TrtCommon(
  36. std::string model_path, std::string precision, std::string input_name, std::string output_name)
  37. : model_file_path_(model_path),
  38. precision_(precision),
  39. input_name_(input_name),
  40. output_name_(output_name),
  41. is_initialized_(false)
  42. {
  43. runtime_ = UniquePtr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(logger_));
  44. }
  45. void TrtCommon::setup()
  46. {
  47. const fs::path path(model_file_path_);
  48. std::string extension = path.extension().string();
  49. if (fs::exists(path)) {
  50. if (extension == ".engine") {
  51. loadEngine(model_file_path_);
  52. } else if (extension == ".onnx") {
  53. fs::path cache_engine_path{model_file_path_};
  54. cache_engine_path.replace_extension("engine");
  55. if (fs::exists(cache_engine_path)) {
  56. loadEngine(cache_engine_path.string());
  57. } else {
  58. logger_.log(nvinfer1::ILogger::Severity::kINFO, "start build engine");
  59. buildEngineFromOnnx(model_file_path_, cache_engine_path.string());
  60. logger_.log(nvinfer1::ILogger::Severity::kINFO, "end build engine");
  61. }
  62. } else {
  63. is_initialized_ = false;
  64. return;
  65. }
  66. } else {
  67. is_initialized_ = false;
  68. return;
  69. }
  70. context_ = UniquePtr<nvinfer1::IExecutionContext>(engine_->createExecutionContext());
  71. #if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500
  72. input_dims_ = engine_->getTensorShape(input_name_.c_str());
  73. output_dims_ = engine_->getTensorShape(output_name_.c_str());
  74. #else
  75. // Deprecated since 8.5
  76. input_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(input_name_.c_str()));
  77. output_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(output_name_.c_str()));
  78. #endif
  79. is_initialized_ = true;
  80. }
  81. bool TrtCommon::loadEngine(std::string engine_file_path)
  82. {
  83. std::ifstream engine_file(engine_file_path);
  84. std::stringstream engine_buffer;
  85. engine_buffer << engine_file.rdbuf();
  86. std::string engine_str = engine_buffer.str();
  87. engine_ = UniquePtr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(
  88. reinterpret_cast<const void *>(engine_str.data()), engine_str.size()));
  89. return true;
  90. }
  91. bool TrtCommon::buildEngineFromOnnx(std::string onnx_file_path, std::string output_engine_file_path)
  92. {
  93. auto builder = UniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger_));
  94. const auto explicitBatch =
  95. 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  96. auto network = UniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
  97. auto config = UniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
  98. auto parser = UniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, logger_));
  99. if (!parser->parseFromFile(
  100. onnx_file_path.c_str(), static_cast<int>(nvinfer1::ILogger::Severity::kERROR))) {
  101. return false;
  102. }
  103. #if (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 8400
  104. config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 16 << 20);
  105. #else
  106. config->setMaxWorkspaceSize(16 << 20);
  107. #endif
  108. if (precision_ == "fp16") {
  109. config->setFlag(nvinfer1::BuilderFlag::kFP16);
  110. } else if (precision_ == "int8") {
  111. config->setFlag(nvinfer1::BuilderFlag::kINT8);
  112. } else {
  113. return false;
  114. }
  115. auto plan = UniquePtr<nvinfer1::IHostMemory>(builder->buildSerializedNetwork(*network, *config));
  116. if (!plan) {
  117. return false;
  118. }
  119. engine_ =
  120. UniquePtr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(plan->data(), plan->size()));
  121. if (!engine_) {
  122. return false;
  123. }
  124. // save engine
  125. std::ofstream file;
  126. file.open(output_engine_file_path, std::ios::binary | std::ios::out);
  127. if (!file.is_open()) {
  128. return false;
  129. }
  130. file.write((const char *)plan->data(), plan->size());
  131. file.close();
  132. return true;
  133. }
  134. bool TrtCommon::isInitialized() { return is_initialized_; }
  135. int TrtCommon::getNumInput()
  136. {
  137. return std::accumulate(
  138. input_dims_.d, input_dims_.d + input_dims_.nbDims, 1, std::multiplies<int>());
  139. }
  140. int TrtCommon::getNumOutput()
  141. {
  142. return std::accumulate(
  143. output_dims_.d, output_dims_.d + output_dims_.nbDims, 1, std::multiplies<int>());
  144. }
  145. } // namespace Tn