使用C++实现一个高效的线程池
在多线程编程中,线程池是一种常见且高效的设计模式。它通过预先创建一定数量的线程来处理任务,从而避免频繁创建和销毁线程带来的性能开销。本文将详细介绍如何使用C++实现一个线程池,并解析相关代码实现细节。
线程池简介
线程池(Thread Pool)是一种管理和复用线程的机制。它通过维护一个线程集合,当有任务需要执行时,从线程池中分配一个空闲线程来处理任务,任务完成后线程归还到池中。这样可以显著减少线程创建和销毁的开销,提高系统的整体性能和响应速度。
设计思路
本文实现的线程池主要包含两个核心类:
- Thread类:封装了单个线程的创建、启动和管理。
- ThreadPool类:管理多个线程,维护任务队列,并调度任务给线程执行。
线程池支持两种模式:
- MODE_CACHED:缓存模式,根据任务量动态调整线程数量,适用于任务量不固定的场景。
- MODE_FIXED:固定模式,线程数量固定,适用于任务量稳定的场景。
Thread类实现
`Thread`类负责封装单个线程的创建和管理。以下是`Thread.h`和`Thread.cpp`的实现。
Thread.h
#include <functional>
#include <atomic>
#include <cstdint>
#include <thread>
/// Thin wrapper around std::thread that runs a user callback, passing it an ID
/// drawn from a static counter at construction time.
///
/// Copying is deleted; since no move operations are declared either, objects of
/// this type are neither copyable nor movable.
class Thread {
public:
    /// Signature of the thread body; it receives this object's ID (see getID()).
    using ThreadFunc = std::function<void(std::uint32_t)>;

public:
    /// Stores `func`; no OS thread exists until start() is called.
    explicit Thread(ThreadFunc func);

    /// Joins the underlying std::thread if it is still joinable. start() detaches
    /// the thread, so after start() this has nothing to wait for.
    void join();

    ~Thread();

    /// Launches the OS thread running `func(getID())` and detaches it.
    void start();

    /// ID assigned to this object at construction.
    [[nodiscard]] std::uint32_t getID() const;

    /// Number of Thread objects constructed so far in the process.
    [[nodiscard]] static std::uint32_t getNumCreated();

    Thread(const Thread &) = delete;
    Thread &operator=(const Thread &) = delete;

private:
    ThreadFunc m_func;                                    // callback executed on the OS thread
    std::uint32_t m_threadID;                             // ID taken from m_numCreateThread
    std::thread m_thread;                                 // underlying thread handle
    static std::atomic<std::uint32_t> m_numCreateThread;  // Thread objects constructed so far
};
Thread.cpp
#include "Thread.h"
std::atomic<uint32_t> Thread::m_numCreateThread(0);
// Stores the callback (the OS thread is only created later, by start()) and gives
// this object an ID taken from the shared static counter.
Thread::Thread(Thread::ThreadFunc func)
    : m_func(std::move(func)),
      // fetch_add reads and increments the counter in a single atomic step, so two
      // Thread objects constructed concurrently can never receive the same ID.
      // A separate load() followed by ++ leaves a window where both constructors
      // observe the same value.
      m_threadID(m_numCreateThread.fetch_add(1)) {
}
void Thread::start() {
    // Entry point of the OS thread: run the stored callback with our ID.
    // NOTE(review): the lambda captures `this`, so this Thread object must stay alive
    // at least until m_func has been entered on the new thread.
    auto entry = [this]() { m_func(m_threadID); };
    m_thread = std::thread(entry);
    // Detached: the OS reclaims the thread when it finishes, and from here on the
    // handle is no longer joinable, so join()/~Thread() do nothing for it.
    m_thread.detach();
}
uint32_t Thread::getID() const {
return m_threadID;
}
uint32_t Thread::getNumCreated() {
return Thread::m_numCreateThread.load();
}
// Joins the thread if its handle is still joinable. Because start() detaches the
// thread, this is a no-op for any Thread that has been started.
Thread::~Thread() {
    this->join();
}
// Waits for the underlying thread to finish when the handle is joinable. Otherwise
// (never started, already joined, or detached by start()) it returns immediately.
void Thread::join() {
    if (!m_thread.joinable())
        return;
    m_thread.join();
}
解析
- 成员变量:
  - `m_func`:线程执行的函数。
  - `m_threadID`:线程的唯一标识,构造时从静态计数器中取得。
  - `m_thread`:底层的`std::thread`对象。
  - `m_numCreateThread`:静态原子变量,用于记录已创建的`Thread`对象数量。
- 构造函数:
  - 接受一个函数作为参数,并从静态计数器分配一个递增的线程ID。此时线程尚未启动,需要调用`start()`。
- `start`方法:
  - 启动线程执行传入的函数,并立即将线程设为分离(detach)状态,使线程结束时由系统自动回收资源。
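- `join`方法和析构函数:
  - 如果线程可连接(joinable),则执行`join`操作,确保线程资源被正确回收。
  - 需要注意:`start()`中已经调用了`detach()`,因此线程启动后`m_thread`不再可连接,`join()`和析构函数实际上不会等待线程结束;线程的生命周期由`ThreadPool`负责管理。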
ThreadPool类实现
`ThreadPool`类负责管理多个线程,维护任务队列,并调度任务给线程执行。以下是`ThreadPool.h`和`ThreadPool.cpp`的实现。
ThreadPool.h
#include <mutex>
#include <unordered_map>
#include <memory>
#include <functional>
#include <queue>
#include <iostream>
#include <condition_variable>
#include <future>
#include <cstdint>
#include "Thread.h"
/// Scheduling policy of a ThreadPool.
enum class THREAD_MODE {
    MODE_CACHED,  ///< submitTask() may spawn extra workers; idle surplus workers retire
    MODE_FIXED    ///< submitTask() never spawns workers and idle workers never retire
};
/// Thread pool that runs queued tasks on a set of worker threads.
///
/// Workers (see ThreadFun) pull tasks from a bounded FIFO queue guarded by m_mutex.
/// The configuration setters are rejected while the pool is running.
class ThreadPool {
public:
    /// @param mode           scheduling policy (see THREAD_MODE)
    /// @param maxThreadSize  upper bound on workers when submitTask() grows a MODE_CACHED pool
    /// @param initThreadSize MODE_CACHED only: idle workers retire only while the pool holds
    ///                       more than this many threads. Note that start()'s default thread
    ///                       count is std::thread::hardware_concurrency(), not this value.
    /// @param maxTaskSize    capacity of the pending-task queue
    explicit ThreadPool(THREAD_MODE mode = THREAD_MODE::MODE_CACHED, std::uint32_t maxThreadSize = 1024,
                        std::uint32_t initThreadSize = 4, std::uint32_t maxTaskSize = 1024);

    /// Stops the pool and blocks until every worker has removed itself from the pool.
    ~ThreadPool();

    ThreadPool(const ThreadPool &) = delete;
    ThreadPool &operator=(const ThreadPool &) = delete;

    // Configuration setters: each one is ignored, with a message on std::cerr,
    // while the pool is running.
    void setThreadMaxSize(std::uint32_t maxSize);
    void setMode(THREAD_MODE mode);
    void setTaskMaxSize(std::uint32_t maxSize);

    /// Launches worker threads; `taskSize` is the number of threads to create.
    void start(std::uint32_t taskSize = std::thread::hardware_concurrency());

    /// Queues `func(args...)` and returns a future for its result.
    /// Throws std::runtime_error if the task queue stays full for 3 seconds.
    template<typename Func, typename ...Args>
    auto submitTask(Func &&func, Args &&...args) -> std::future<typename std::invoke_result<Func, Args...>::type>;

protected:
    /// True while the pool is running.
    [[nodiscard]] bool checkState() const;
    /// Worker loop executed by each pool thread.
    void ThreadFun(std::uint32_t threadID);

private:
    using Task = std::function<void()>;

    std::unordered_map<std::uint32_t, std::unique_ptr<Thread>> m_threads;  // workers keyed by Thread ID
    std::uint32_t m_initThreadSize;                  // initial thread count
    std::atomic<std::uint32_t> m_spareThreadSize;    // idle thread count
    std::uint32_t m_maxThreadSize;                   // maximum thread count
    std::atomic<bool> m_isRunning;                   // pool running flag
    THREAD_MODE m_mode;                              // pool scheduling mode
    std::deque<Task> m_tasks;                        // pending tasks, FIFO
    std::atomic<std::uint32_t> m_taskSize;           // set to 0 by the constructor; not otherwise used in the code shown
    std::uint32_t m_maxTaskSize;                     // capacity of m_tasks
    std::uint32_t m_thread_maxSpareTime;             // idle seconds before a surplus MODE_CACHED worker retires
    mutable std::mutex m_mutex;                      // pool mutex; guards m_threads and m_tasks
    std::condition_variable m_notEmpty;              // signalled when a task is queued and on shutdown
    std::condition_variable m_notFull;               // signalled when a queue slot frees up
    std::condition_variable m_isExit;                // signalled when a worker leaves m_threads
};
ThreadPool.cpp(注意:文件开头包含的`ThreadPool.hpp`应与上文的头文件名`ThreadPool.h`保持一致)
#include "ThreadPool.hpp"
#include <thread>
ThreadPool::ThreadPool(THREAD_MODE mode, uint32_t maxThreadSize, uint32_t initThreadSize,
uint32_t maxTaskSize) : m_initThreadSize(initThreadSize), m_spareThreadSize(0),
m_maxThreadSize(maxThreadSize), m_isRunning(false),
m_mode(mode), m_taskSize(0), m_maxTaskSize(maxTaskSize),
m_thread_maxSpareTime(60) {
}
// Reports whether the pool is running (set by start(), cleared by the destructor).
bool ThreadPool::checkState() const {
    const bool running = m_isRunning.load();
    return running;
}
void ThreadPool::setThreadMaxSize(uint32_t maxSize) {
if (checkState())
std::cerr << "threadPool is running, cannot change!" << std::endl;
else
this->m_maxThreadSize = maxSize;
}
// Switches the scheduling policy. Refused, with a message on std::cerr, while the
// pool is running.
void ThreadPool::setMode(THREAD_MODE mode) {
    const bool running = checkState();
    if (running) {
        std::cerr << "threadPool is running, cannot change!" << std::endl;
        return;
    }
    m_mode = mode;
}
void ThreadPool::setTaskMaxSize(uint32_t maxSize) {
if (checkState())
std::cerr << "threadPool is running, cannot change!" << std::endl;
else
this->m_maxTaskSize = maxSize;
}
// Worker loop run by every pool thread; `threadID` is this worker's key in m_threads.
//
// Each iteration waits under m_mutex for a task, takes it from the front of the queue,
// and runs it outside the lock. The loop ends when:
//  - the pool has stopped and the queue is empty: the worker removes itself from
//    m_threads and signals m_isExit so the destructor can finish; or
//  - MODE_CACHED only: the pool holds more than m_initThreadSize workers and this one has
//    been idle for more than m_thread_maxSpareTime seconds: it retires the same way.
//
// NOTE(review): erasing this worker's entry from m_threads destroys the Thread object whose
// m_func is currently executing this function. Nothing touches that object afterwards, but
// the pattern is fragile.
void ThreadPool::ThreadFun(uint32_t threadID) {
    // steady_clock rather than high_resolution_clock: the latter may alias system_clock,
    // whose wall-clock adjustments would distort (or freeze) the idle-time measurement.
    auto last_time = std::chrono::steady_clock::now();
    for (;;) {
        Task task;
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            std::cout << "threadID: " << threadID << " trying to get a task" << std::endl;
            while (m_tasks.empty() && m_isRunning) {
                if (m_mode == THREAD_MODE::MODE_CACHED && m_threads.size() > m_initThreadSize) {
                    // Surplus cached worker: wake every 3 s to check whether it has been idle too long.
                    if (m_notEmpty.wait_for(lock, std::chrono::seconds(3)) == std::cv_status::timeout) {
                        auto now_time = std::chrono::steady_clock::now();
                        auto dur_time = std::chrono::duration_cast<std::chrono::seconds>(now_time - last_time);
                        if (dur_time.count() > m_thread_maxSpareTime && m_threads.size() > m_initThreadSize) {
                            m_threads.erase(threadID);
                            m_spareThreadSize--;
                            std::cout << "threadID: " << threadID << " exiting due to inactivity!" << std::endl;
                            // The destructor may already be blocked on m_isExit: a timeout can
                            // race with its shutdown notify. If this was the last worker,
                            // skipping this signal would leave the destructor waiting forever.
                            m_isExit.notify_all();
                            return;
                        }
                    }
                } else {
                    m_notEmpty.wait(lock);
                }
            }
            // Shutdown: leave only once the queue has been drained.
            if (!m_isRunning && m_tasks.empty()) {
                m_threads.erase(threadID);
                std::cout << "threadID: " << threadID << " exiting!" << std::endl;
                m_isExit.notify_all();
                return;
            }
            if (!m_tasks.empty()) {
                m_spareThreadSize--;
                task = std::move(m_tasks.front());
                m_tasks.pop_front();
                std::cout << "threadID: " << threadID << " retrieved a task!" << std::endl;
                // Pass the wake-up on if more work is queued, and tell blocked submitters
                // that a queue slot has been freed.
                if (!m_tasks.empty())
                    m_notEmpty.notify_all();
                m_notFull.notify_all();
            }
        }
        // Run the task without holding the pool lock.
        if (task) {
            try {
                task();
            } catch (const std::exception &e) {
                std::cerr << "Exception in task: " << e.what() << std::endl;
            } catch (...) {
                std::cerr << "Unknown exception in task." << std::endl;
            }
            std::cout << "threadID: " << threadID << " completed a task." << std::endl;
            m_spareThreadSize++;
            last_time = std::chrono::steady_clock::now();
        }
    }
}
void ThreadPool::start(std::uint32_t taskSize) {
m_isRunning = true;
for (std::uint32_t i = 0; i < taskSize; ++i) {
auto ptr = std::make_unique<Thread>(std::bind(&ThreadPool::ThreadFun, this, std::placeholders::_1));
auto threadID = ptr->getID();
m_threads.emplace(threadID, std::move(ptr));
}
for (auto &it: m_threads) {
it.second->start();
m_spareThreadSize++;
}
}
// Shuts the pool down: clears the running flag, wakes every waiter, then blocks until
// all workers have removed themselves from m_threads. Workers exit only after the
// queue is empty, so this also waits for the pending tasks to run.
ThreadPool::~ThreadPool() {
    m_isRunning.store(false);
    std::unique_lock guard{m_mutex};
    m_notEmpty.notify_all();
    m_notFull.notify_all();
    auto allWorkersGone = [this]() -> bool { return m_threads.empty(); };
    m_isExit.wait(guard, allWorkersGone);
}
submitTask模板方法实现(`submitTask`是成员函数模板,其定义必须对调用方可见,因此应放在头文件`ThreadPool.h`中,而不是`ThreadPool.cpp`里,否则会出现链接错误)
template<typename Func, typename ...Args>
auto ThreadPool::submitTask(Func &&func, Args &&...args) -> std::future<typename std::invoke_result<Func, Args...>::type> {
using Rtype = typename std::invoke_result<Func, Args...>::type;
auto task = std::make_shared<std::packaged_task<Rtype()>>(
std::bind(std::forward<Func>(func), std::forward<Args>(args)...));
std::future<Rtype> result = task->get_future();
std::unique_lock lock(m_mutex);
if (!m_notFull.wait_for(lock, std::chrono::seconds(3),
[&]() -> bool { return m_tasks.size() < m_maxTaskSize; })) {
std::cerr << "Task queue is full, submit task failed!" << std::endl;
throw std::runtime_error("Task queue is full");
}
m_tasks.emplace_back([task] { (*task)(); });
m_notEmpty.notify_all();
if (m_mode == THREAD_MODE::MODE_CACHED && m_tasks.size() > m_spareThreadSize) {
if (m_threads.size() >= m_maxThreadSize) {
std::cerr << "Thread pool has reached max size, cannot create new thread!" << std::endl;
} else {
std::cout << "Creating a new thread!" << std::endl;
auto ptr = std::make_unique<Thread>(std::bind(&ThreadPool::ThreadFun, this, std::placeholders::_1));
u_int64_t threadID = ptr->getID();
m_threads.emplace(threadID, std::move(ptr));
m_threads[threadID]->start();
++m_spareThreadSize;
}
}
return result;
}
解析
- 成员变量:
  - `m_threads`:存储所有线程的集合,以线程ID为键。
  - `m_tasks`:任务队列,存储待执行的任务。
  - `m_mutex`、`m_notEmpty`、`m_notFull`、`m_isExit`:用于线程同步和任务调度的互斥量和条件变量。
  - 其他变量用于控制线程池的状态,如最大线程数、初始线程数、任务队列最大长度等。
- 构造函数:
  - 初始化线程池的各项参数,如模式、最大线程数、初始线程数、最大任务数等;构造时不会创建任何线程。
- `start`方法:
  - 启动线程池:按参数指定的数量创建线程并启动它们。注意该参数的默认值是`std::thread::hardware_concurrency()`,而不是构造函数中的初始线程数`initThreadSize`。
- `submitTask`模板方法:
  - 提交任务到线程池,支持任意可调用对象及其参数。
  - 使用`std::packaged_task`和`std::future`实现任务的异步执行和结果获取。
  - 如果任务队列已满,则最多等待3秒,若仍满则抛出`std::runtime_error`异常。
  - 在缓存模式下,若队列中的任务数多于空闲线程数且线程总数未达到上限,则创建一个新线程。
- `ThreadFun`方法:
  - 线程的工作函数,循环从任务队列中获取任务,并在不持有锁的情况下执行。
  - 在缓存模式下,当线程总数超过初始线程数时,空闲超过60秒(`m_thread_maxSpareTime`)的线程会自动退出,以降低资源占用。
- 析构函数:
  - 关闭线程池,通知所有线程退出,并等待所有线程结束。线程会先执行完队列中剩余的任务再退出。
线程池的使用
以下是一个简单的示例,展示如何使用上述实现的线程池。
#include "ThreadPool.h"
#include <chrono>
#include <future>
#include <iostream>
#include <thread>
#include <vector>
// 示例任务函数
// Demo workload: logs its start, blocks the calling worker for two seconds,
// then logs its completion.
void exampleTask(int n) {
    const auto pause = std::chrono::seconds(2);
    std::cout << "Task " << n << " is starting." << std::endl;
    std::this_thread::sleep_for(pause);
    std::cout << "Task " << n << " is completed." << std::endl;
}
int main() {
    // Cached-mode pool: at most 8 threads, initThreadSize 4, at most 16 queued tasks.
    ThreadPool pool(THREAD_MODE::MODE_CACHED, 8, 4, 16);
    // Pass the worker count explicitly: start()'s default is
    // std::thread::hardware_concurrency(), not the initThreadSize given above.
    pool.start(4);

    // Submit the tasks, keeping one future per task.
    constexpr int kTaskCount = 10;
    std::vector<std::future<void>> futures;
    futures.reserve(kTaskCount);
    for (int i = 0; i < kTaskCount; ++i) {
        futures.emplace_back(pool.submitTask(exampleTask, i));
    }
    // Wait for every task; get() also rethrows any exception a task raised.
    for (auto &fut : futures) {
        fut.get();
    }
    std::cout << "All tasks have been completed." << std::endl;
    return 0;
}
运行结果
threadID: 0 trying to get a task
threadID: 1 trying to get a task
threadID: 2 trying to get a task
threadID: 3 trying to get a task
Task 0 is starting.
Task 1 is starting.
Task 2 is starting.
Task 3 is starting.
threadID: 0 completed a task.
threadID: 0 trying to get a task
Task 4 is starting.
threadID: 1 completed a task.
threadID: 1 trying to get a task
Task 5 is starting.
...
All tasks have been completed.