// cpm.hpp
#ifndef __CPM_HPP__
#define __CPM_HPP__
// Consumer Producer Model
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
namespace cpm {
  11. template <typename Result, typename Input, typename Model>
  12. class Instance {
  13. protected:
  14. struct Item {
  15. Input input;
  16. std::shared_ptr<std::promise<Result>> pro;
  17. };
  18. std::condition_variable cond_;
  19. std::queue<Item> input_queue_;
  20. std::mutex queue_lock_;
  21. std::shared_ptr<std::thread> worker_;
  22. volatile bool run_ = false;
  23. volatile int max_items_processed_ = 0;
  24. void *stream_ = nullptr;
  25. public:
  26. virtual ~Instance() { stop(); }
  27. void stop() {
  28. run_ = false;
  29. cond_.notify_one();
  30. {
  31. std::unique_lock<std::mutex> l(queue_lock_);
  32. while (!input_queue_.empty()) {
  33. auto &item = input_queue_.front();
  34. if (item.pro) item.pro->set_value(Result());
  35. input_queue_.pop();
  36. }
  37. };
  38. if (worker_) {
  39. worker_->join();
  40. worker_.reset();
  41. }
  42. }
  43. virtual std::shared_future<Result> commit(const Input &input) {
  44. Item item;
  45. item.input = input;
  46. item.pro.reset(new std::promise<Result>());
  47. {
  48. std::unique_lock<std::mutex> __lock_(queue_lock_);
  49. input_queue_.push(item);
  50. }
  51. cond_.notify_one();
  52. return item.pro->get_future();
  53. }
  54. virtual std::vector<std::shared_future<Result>> commits(const std::vector<Input> &inputs) {
  55. std::vector<std::shared_future<Result>> output;
  56. {
  57. std::unique_lock<std::mutex> __lock_(queue_lock_);
  58. for (int i = 0; i < (int)inputs.size(); ++i) {
  59. Item item;
  60. item.input = inputs[i];
  61. item.pro.reset(new std::promise<Result>());
  62. output.emplace_back(item.pro->get_future());
  63. input_queue_.push(item);
  64. }
  65. }
  66. cond_.notify_one();
  67. return output;
  68. }
  69. template <typename LoadMethod>
  70. bool start(const LoadMethod &loadmethod, int max_items_processed = 1, void *stream = nullptr) {
  71. stop();
  72. this->stream_ = stream;
  73. this->max_items_processed_ = max_items_processed;
  74. std::promise<bool> status;
  75. worker_ = std::make_shared<std::thread>(&Instance::worker<LoadMethod>, this,
  76. std::ref(loadmethod), std::ref(status));
  77. return status.get_future().get();
  78. }
  79. private:
  80. template <typename LoadMethod>
  81. void worker(const LoadMethod &loadmethod, std::promise<bool> &status) {
  82. std::shared_ptr<Model> model = loadmethod();
  83. if (model == nullptr) {
  84. status.set_value(false);
  85. return;
  86. }
  87. run_ = true;
  88. status.set_value(true);
  89. std::vector<Item> fetch_items;
  90. std::vector<Input> inputs;
  91. while (get_items_and_wait(fetch_items, max_items_processed_)) {
  92. inputs.resize(fetch_items.size());
  93. std::transform(fetch_items.begin(), fetch_items.end(), inputs.begin(),
  94. [](Item &item) { return item.input; });
  95. auto ret = model->forwards(inputs, stream_);
  96. for (int i = 0; i < (int)fetch_items.size(); ++i) {
  97. if (i < (int)ret.size()) {
  98. fetch_items[i].pro->set_value(ret[i]);
  99. } else {
  100. fetch_items[i].pro->set_value(Result());
  101. }
  102. }
  103. inputs.clear();
  104. fetch_items.clear();
  105. }
  106. model.reset();
  107. run_ = false;
  108. }
  109. virtual bool get_items_and_wait(std::vector<Item> &fetch_items, int max_size) {
  110. std::unique_lock<std::mutex> l(queue_lock_);
  111. cond_.wait(l, [&]() { return !run_ || !input_queue_.empty(); });
  112. if (!run_) return false;
  113. fetch_items.clear();
  114. for (int i = 0; i < max_size && !input_queue_.empty(); ++i) {
  115. fetch_items.emplace_back(std::move(input_queue_.front()));
  116. input_queue_.pop();
  117. }
  118. return true;
  119. }
  120. virtual bool get_item_and_wait(Item &fetch_item) {
  121. std::unique_lock<std::mutex> l(queue_lock_);
  122. cond_.wait(l, [&]() { return !run_ || !input_queue_.empty(); });
  123. if (!run_) return false;
  124. fetch_item = std::move(input_queue_.front());
  125. input_queue_.pop();
  126. return true;
  127. }
  128. };
};  // namespace cpm
#endif  // __CPM_HPP__