123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- #ifndef __CPM_HPP__
- #define __CPM_HPP__
// Consumer-producer model: callers queue inputs, a single worker thread runs them through a model in batches.
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <exception>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>
namespace cpm {

/// Single-worker consumer/producer wrapper around a model.
///
/// Producers call commit()/commits() from any thread. Each input is queued
/// together with a promise, and the caller receives a shared_future for its result.
/// start() launches one worker thread. That thread loads the model, then
/// repeatedly takes up to `max_items_processed` queued inputs. It calls
/// `model->forwards(inputs, stream)` and fulfils the promises in order.
///
/// Requirements visible from this code:
///  - `Model::forwards(std::vector<Input>&, void*)` returns a container with
///    `size()` and `operator[]` whose elements convert to `Result`.
///  - `Result` is default-constructible. A default `Result()` is delivered for
///    items the model produced no output for, and for items still queued at stop().
///
/// NOTE(review): the destructor calls stop(). A derived class that overrides
/// get_items_and_wait()/get_item_and_wait() should call stop() in its own
/// destructor, so the worker is joined before the derived part is destroyed.
template <typename Result, typename Input, typename Model>
class Instance {
 protected:
  /// One queued request: the input plus the promise that receives its result.
  struct Item {
    Input input;
    std::shared_ptr<std::promise<Result>> pro;
  };

  std::condition_variable cond_;
  std::queue<Item> input_queue_;
  // Guards input_queue_. Every run_ transition the worker waits on is made under it.
  std::mutex queue_lock_;
  std::shared_ptr<std::thread> worker_;
  // Atomic rather than `volatile`: volatile gives no inter-thread guarantees.
  std::atomic<bool> run_{false};
  // Written by start() before the worker thread is created; read only by the worker.
  int max_items_processed_ = 0;
  void *stream_ = nullptr;

 public:
  virtual ~Instance() { stop(); }

  /// Stops and joins the worker (if any). Then resolves every input still
  /// queued with a default-constructed Result. Safe to call repeatedly.
  void stop() {
    {
      // Setting the flag under the lock prevents a lost wakeup. Otherwise the
      // worker could test the predicate, miss the notify, and block forever.
      std::unique_lock<std::mutex> lock(queue_lock_);
      run_ = false;
    }
    cond_.notify_all();
    if (worker_) {
      if (worker_->joinable()) worker_->join();
      worker_.reset();
    }
    // Drain after join so inputs committed during shutdown are resolved too.
    std::unique_lock<std::mutex> lock(queue_lock_);
    while (!input_queue_.empty()) {
      auto &item = input_queue_.front();
      if (item.pro) item.pro->set_value(Result());
      input_queue_.pop();
    }
  }

  /// Queues one input. Returns a future that becomes ready once the worker has
  /// processed it, or once stop() discards it.
  virtual std::shared_future<Result> commit(const Input &input) {
    Item item;
    item.input = input;
    item.pro = std::make_shared<std::promise<Result>>();
    // Take the future before the item becomes visible to the worker.
    std::shared_future<Result> result = item.pro->get_future();
    {
      std::unique_lock<std::mutex> lock(queue_lock_);
      input_queue_.push(std::move(item));
    }
    cond_.notify_one();
    return result;
  }

  /// Queues all `inputs` atomically with respect to other producers.
  /// Returns one future per input, in the same order.
  virtual std::vector<std::shared_future<Result>> commits(const std::vector<Input> &inputs) {
    std::vector<std::shared_future<Result>> output;
    output.reserve(inputs.size());
    {
      std::unique_lock<std::mutex> lock(queue_lock_);
      for (const Input &input : inputs) {
        Item item;
        item.input = input;
        item.pro = std::make_shared<std::promise<Result>>();
        output.emplace_back(item.pro->get_future());
        input_queue_.push(std::move(item));
      }
    }
    cond_.notify_one();
    return output;
  }

  /// (Re)starts the worker thread.
  ///
  /// @param loadmethod  callable returning something convertible to
  ///                    std::shared_ptr<Model>. It is called on the worker
  ///                    thread, and start() blocks until it has returned.
  /// @param max_items_processed  maximum batch size (values < 1 behave as 1).
  /// @param stream      opaque pointer forwarded to Model::forwards.
  /// @return true if the model loaded (non-null), false otherwise.
  /// @throws whatever `loadmethod` throws, rethrown on the calling thread.
  template <typename LoadMethod>
  bool start(const LoadMethod &loadmethod, int max_items_processed = 1, void *stream = nullptr) {
    stop();
    this->stream_ = stream;
    this->max_items_processed_ = max_items_processed;
    std::promise<bool> status;
    std::future<bool> started = status.get_future();
    // The promise is moved into the thread so the worker never touches a
    // stack object of this frame after start() returns.
    worker_ = std::make_shared<std::thread>(&Instance::worker<LoadMethod>, this,
                                            std::cref(loadmethod), std::move(status));
    return started.get();
  }

 private:
  // Delivers `error` to every promise in `items` that has not been satisfied yet.
  static void fail_items(std::vector<Item> &items, const std::exception_ptr &error) {
    for (Item &item : items) {
      if (!item.pro) continue;
      try {
        item.pro->set_exception(error);
      } catch (const std::future_error &) {
        // This promise already received its value before the failure occurred.
      }
    }
  }

  // Worker thread body: load the model, report the outcome through `status`,
  // then process batches until stop() clears run_.
  template <typename LoadMethod>
  void worker(const LoadMethod &loadmethod, std::promise<bool> status) {
    std::shared_ptr<Model> model;
    try {
      model = loadmethod();
    } catch (...) {
      // Hand the loader's exception to start() instead of terminating the process.
      status.set_exception(std::current_exception());
      return;
    }
    if (model == nullptr) {
      status.set_value(false);
      return;
    }
    run_ = true;
    status.set_value(true);

    std::vector<Item> fetch_items;
    std::vector<Input> inputs;
    while (get_items_and_wait(fetch_items, max_items_processed_)) {
      inputs.clear();
      inputs.reserve(fetch_items.size());
      for (Item &item : fetch_items) inputs.push_back(std::move(item.input));
      try {
        auto ret = model->forwards(inputs, stream_);
        const int produced = static_cast<int>(ret.size());
        for (int i = 0; i < static_cast<int>(fetch_items.size()); ++i) {
          const auto &pro = fetch_items[i].pro;
          if (!pro) continue;
          // Items beyond what the model returned get a default Result.
          if (i < produced) {
            pro->set_value(ret[i]);
          } else {
            pro->set_value(Result());
          }
        }
      } catch (...) {
        // An exception must not escape the thread (std::terminate). Surface it
        // to the waiters of this batch instead.
        fail_items(fetch_items, std::current_exception());
      }
      fetch_items.clear();
    }
    model.reset();
    run_ = false;
  }

  // Blocks until items are queued or the instance is stopped. Moves up to
  // `max_size` items (at least 1) into `fetch_items`. Returns false when stopped.
  virtual bool get_items_and_wait(std::vector<Item> &fetch_items, int max_size) {
    std::unique_lock<std::mutex> lock(queue_lock_);
    cond_.wait(lock, [&]() { return !run_ || !input_queue_.empty(); });
    if (!run_) return false;
    fetch_items.clear();
    // A limit below 1 would return empty batches forever while items are queued.
    const int limit = std::max(1, max_size);
    for (int i = 0; i < limit && !input_queue_.empty(); ++i) {
      fetch_items.emplace_back(std::move(input_queue_.front()));
      input_queue_.pop();
    }
    return true;
  }

  // Single-item variant of get_items_and_wait(). Returns false when stopped.
  virtual bool get_item_and_wait(Item &fetch_item) {
    std::unique_lock<std::mutex> lock(queue_lock_);
    cond_.wait(lock, [&]() { return !run_ || !input_queue_.empty(); });
    if (!run_) return false;
    fetch_item = std::move(input_queue_.front());
    input_queue_.pop();
    return true;
  }
};

}  // namespace cpm
- #endif // __CPM_HPP__
|