C++簡易執行緒池

kewei168發表於2020-11-27

最近作業剛好用到了多執行緒的內容,又重新寫了一遍執行緒池,加深了對其的理解。這裡基於C++11的thread來實現一個簡單通用的執行緒池,基本思路是,建構函式裡面建立一定數量的執行緒,所有執行緒共享一個任務佇列,每個執行緒進入一個“死”迴圈,監聽任務佇列,一旦來了新的任務,則喚醒一個執行緒執行任務。

實現

執行緒池有幾個關鍵的變數:

  • std::vector<std::thread> threads; — 儲存所有的執行緒例項,用於解構函式時候銷燬
  • std::queue<std::function<void(void)> > tasks; — 共享的任務佇列,最好使用queue或者list等,頭尾操作(如:push,pop)都是 O ( 1 ) O(1) O(1)
  • std::mutex mtx; — 全域性鎖,用於保護對於共享任務佇列的訪問
  • std::condition_variable cv; — 條件變數,用於喚醒執行緒

幾個關鍵函式詳解:

(1)新增任務

外部函式通過該函式新增任務到執行緒池內部,注意由於任務佇列是所有執行緒共享的,所以這裡新增任務之前,需要先上鎖保證執行緒安全。該函式的最後一行this->cv.notify_one();,就是隨機喚醒一個閒置執行緒來執行該任務(如果執行緒都在忙,則最先閒置下來的執行緒執行該任務)。

void addTask(std::function<void(void)> task) {
    std::unique_lock<std::mutex> lck(mtx);
    this->tasks.push(task);
    this->numTaskRemaining++;
    // envoke a thread to do the task
    this->cv.notify_one();
}

(2)執行緒任務

每個執行緒被建立後,就會執行該函式。函式一開始cv.wait,等待新的任務,或者執行緒池被銷燬;一旦有新的任務來,則從任務佇列裡面獲取一個任務,然後釋放鎖。因為在任務執行過程中,不涉及任何race condition,並且我們並不知道任務執行的時長(可能會很長),所以我們應該先釋放鎖,讓其他執行緒可以訪問共享變數。等待任務執行結束後,重新上鎖,然後修改剩餘任務數量。

void doTask() {
      while (true) {
        std::unique_lock<std::mutex> lck(this->mtx);
        // use a conditional variable to wait
        this->cv.wait(lck, [this] {
          // already in the critical section, so can access these variables safely
          return !this->tasks.empty() || this->stop;
        });
        if (this->stop) {
          return;
        }
        // fetch a task
        std::function<void(void)> task = std::move(this->tasks.front());
        this->tasks.pop();
        lck.unlock();
        // no need to lock while doing the task
        task();
        // lock again to update the remaing tasks variable
        lck.lock();
        this->numTaskRemaining--;
        // notify the waitAll()
        cv_finished.notify_one();
      }
}

(3)等待所有任務執行完畢

這裡我額外實現了一個waitAll函式,可以等待任務佇列裡面所有的任務被執行完。注意這裡是執行完,而不是佇列為空,這兩者不是一個概念。舉個例子:任務佇列裡面只剩下最後兩個任務,然後此時有兩個執行緒都是空閒的,他們分別獲取並執行一個任務,假設執行緒A執行任務A需要10s,執行緒B執行任務B需要1s;那麼1s後執行緒B執行完任務,呼叫了cv_finished.notify_one();,此時任務佇列已經是空,但是所有任務並沒有全部執行完成(注意任務A還需要9s)。所以這裡我用的是this->numTaskRemaining == 0;而不是this->tasks.empty();

void waitAll() {
    std::unique_lock<std::mutex> lck(mtx);
    this->cv_finished.wait(lck, [this] { return this->numTaskRemaining == 0; });
}

將上述程式碼結合起來,就得到了最終的完整版程式碼。

class ThreadPool {
   private:
    std::vector<std::thread> threads;
    std::queue<std::function<void(void)> > tasks;
    // global mutex, use to protext the task queue
    std::mutex mtx;
    std::condition_variable cv;
    std::condition_variable cv_finished;
    bool stop;
    size_t numTaskRemaining;

    void doTask() {
      while (true) {
        std::unique_lock<std::mutex> lck(this->mtx);
        // use a conditional variable to wait
        this->cv.wait(lck, [this] {
          // already in the critical section, so can access these variables safely
          return !this->tasks.empty() || this->stop;
        });
        if (this->stop) {
          return;
        }
        // fetch a task
        std::function<void(void)> task = std::move(this->tasks.front());
        this->tasks.pop();
        lck.unlock();
        // no need to lock while doing the task
        task();
        // lock again to update the remaing tasks variable
        lck.lock();
        this->numTaskRemaining--;
        // notify the waitAll()
        cv_finished.notify_one();
      }
    }

   public:
    ThreadPool(int cnt) : stop(false), numTaskRemaining(0) {
      // initialize the threadpool
      for (int i = 0; i < cnt; i++) {
        threads.push_back(std::thread([this] { doTask(); }));
      }
    }

    ~ThreadPool() {
      // first finish all remaining tasks
      waitAll();
      std::unique_lock<std::mutex> lck(mtx);
      this->stop = true;
      // notify all thread to finish
      this->cv.notify_all();
      lck.unlock();
      for (auto & th : threads) {
        if (th.joinable()) {
          th.join();
        }
      }
    }

    void addTask(std::function<void(void)> task) {
      std::unique_lock<std::mutex> lck(mtx);
      this->tasks.push(task);
      this->numTaskRemaining++;
      // envoke a thread to do the task
      this->cv.notify_one();
    }

    // This function will notify the threadpool to run all task until the queue is empty;
    void waitAll() {
      std::unique_lock<std::mutex> lck(mtx);
      this->cv_finished.wait(lck, [this] { return this->numTaskRemaining == 0; });
    }
  };

使用

注意我們的執行緒池接受的任務型別是function<void(void)>,就是沒有引數沒有返回值的一個函式,你可能會覺得這個限制很大,但其實我們可以用另一種方式將有參函式轉換為無參函式,那就是c++11退出的lamda函式。

新建一個main.cpp, 寫入如下程式碼:

#include <unistd.h>
#include <iostream>
#include "threadpool.hpp"

void printHello(int taskID, std::thread::id threadID) {
  std::cout << "This is task " << taskID << ", running on thread " << threadID << "\n";
}

int main(int argc, char * argv[]) {
  ECE565::ThreadPool tp(2);

  int num = 1;
  tp.addTask([=] {
    sleep(1);
    printHello(num, std::this_thread::get_id());
  });

  num = 2;
  tp.addTask([=] {
    sleep(3);
    printHello(num, std::this_thread::get_id());
  });

  tp.waitAll();
  return EXIT_SUCCESS;
}

我們可以用[=]{ // do anything you want }這樣的形式,將有參函式轉換為無參函式。其中等於號代表“捕獲”當前所有的變數(值傳遞),可以換成&從而變為引用傳遞,也可以指定“捕獲”具體的變數。

編譯時新增-pthread flag,並指明c++11,如g++ -std=gnu++11 -pthread -o main main.cpp threadpool.hpp,執行即可看到結果。

This is task 1, running on thread 140593016403712
This is task 2, running on thread 140593008011008

相關文章