// Copyright 2020 Tier IV, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #if (defined(_MSC_VER) or (defined(__GNUC__) and (7 <= __GNUC_MAJOR__))) #include namespace fs = ::std::filesystem; #else #include namespace fs = ::std::experimental::filesystem; #endif #include #include namespace Tn { void check_error(const ::cudaError_t e, decltype(__FILE__) f, decltype(__LINE__) n) { if (e != ::cudaSuccess) { std::stringstream s; s << ::cudaGetErrorName(e) << " (" << e << ")@" << f << "#L" << n << ": " << ::cudaGetErrorString(e); throw std::runtime_error{s.str()}; } } TrtCommon::TrtCommon( std::string model_path, std::string precision, std::string input_name, std::string output_name) : model_file_path_(model_path), precision_(precision), input_name_(input_name), output_name_(output_name), is_initialized_(false) { runtime_ = UniquePtr(nvinfer1::createInferRuntime(logger_)); } void TrtCommon::setup() { const fs::path path(model_file_path_); std::string extension = path.extension().string(); if (fs::exists(path)) { if (extension == ".engine") { loadEngine(model_file_path_); } else if (extension == ".onnx") { fs::path cache_engine_path{model_file_path_}; cache_engine_path.replace_extension("engine"); if (fs::exists(cache_engine_path)) { loadEngine(cache_engine_path.string()); } else { logger_.log(nvinfer1::ILogger::Severity::kINFO, "start build engine"); buildEngineFromOnnx(model_file_path_, cache_engine_path.string()); logger_.log(nvinfer1::ILogger::Severity::kINFO, "end build engine"); } } else { is_initialized_ = false; return; } } else { is_initialized_ = false; return; } context_ = UniquePtr(engine_->createExecutionContext()); #if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500 input_dims_ = engine_->getTensorShape(input_name_.c_str()); output_dims_ = engine_->getTensorShape(output_name_.c_str()); #else // Deprecated since 8.5 input_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(input_name_.c_str())); output_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(output_name_.c_str())); #endif is_initialized_ = true; } bool TrtCommon::loadEngine(std::string engine_file_path) { std::ifstream engine_file(engine_file_path); std::stringstream engine_buffer; engine_buffer << engine_file.rdbuf(); std::string engine_str = engine_buffer.str(); engine_ = UniquePtr(runtime_->deserializeCudaEngine( reinterpret_cast(engine_str.data()), engine_str.size())); return true; } bool TrtCommon::buildEngineFromOnnx(std::string onnx_file_path, std::string output_engine_file_path) { auto builder = UniquePtr(nvinfer1::createInferBuilder(logger_)); const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); auto network = UniquePtr(builder->createNetworkV2(explicitBatch)); auto config = UniquePtr(builder->createBuilderConfig()); auto parser = UniquePtr(nvonnxparser::createParser(*network, logger_)); if (!parser->parseFromFile( onnx_file_path.c_str(), static_cast(nvinfer1::ILogger::Severity::kERROR))) { return false; } #if (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 8400 config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 16 << 20); #else config->setMaxWorkspaceSize(16 << 20); #endif if (precision_ == "fp16") { config->setFlag(nvinfer1::BuilderFlag::kFP16); } else if (precision_ == "int8") { config->setFlag(nvinfer1::BuilderFlag::kINT8); } else { return false; } auto plan = UniquePtr(builder->buildSerializedNetwork(*network, *config)); if (!plan) { return false; } engine_ = UniquePtr(runtime_->deserializeCudaEngine(plan->data(), plan->size())); if (!engine_) { return false; } // save engine std::ofstream file; file.open(output_engine_file_path, std::ios::binary | std::ios::out); if (!file.is_open()) { return false; } file.write((const char *)plan->data(), plan->size()); file.close(); return true; } bool TrtCommon::isInitialized() { return is_initialized_; } int TrtCommon::getNumInput() { return std::accumulate( input_dims_.d, input_dims_.d + input_dims_.nbDims, 1, std::multiplies()); } int TrtCommon::getNumOutput() { return std::accumulate( output_dims_.d, output_dims_.d + output_dims_.nbDims, 1, std::multiplies()); } } // namespace Tn