work-stealing算法 ForkJoinPool
专栏系列文章地址:https://blog.csdn.net/qq_26437925/article/details/145290162
本文目标:
- 重点是通过例子程序理解work-stealing算法原理
目录
- work-stealing算法
- 算法原理和优缺点介绍
- 使用场景
- work-stealing例子代码
- ForkJoinPool
- new ForkJoinPool()
- ForkJoinTask
- 例子代码1
- 例子代码2
work-stealing算法
算法原理和优缺点介绍
work-stealing算法是一种用于多线程并行计算中的任务调度算法。该算法的核心思想是允许空闲的线程从其他忙碌线程的工作队列中“窃取”任务来执行,以提高整体的资源利用率和系统的吞吐量。
在work-stealing算法中,每个线程通常都维护一个自己的任务队列,用于存储需要执行的任务。当某个线程完成自己的任务队列中的所有任务后,它会尝试从其他线程的任务队列中窃取任务。为了防止多个线程同时窃取同一个线程的任务,通常需要使用一些同步机制,如锁或原子操作等。
work-stealing算法的优点在于它能够动态地平衡负载,使得各个线程之间的任务分配更加均匀,从而提高了系统的并行效率和资源利用率。此外,该算法还具有较好的可扩展性和适应性,能够随着任务量的增加或减少而自动调整线程的数量和工作负载。
然而,work-stealing算法也存在一些挑战和限制。例如,在窃取任务时需要进行同步操作,这可能会增加一定的开销。此外,如果任务之间存在数据依赖关系,那么窃取任务可能会破坏这种依赖关系,从而导致错误的结果。因此,在使用work-stealing算法时需要根据具体的应用场景和任务特点进行权衡和选择。
介绍部分来自ai回答。可以简单理解如下:
一个大任务分割为若干个互不依赖的子任务,为了减少线程间的竞争,把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应。
线程1 - 队列1(任务1,任务2,任务3,…)
线程2 - 队列2(任务1,任务2,任务3,…)
比如线程1早早的把队列中任务都处理完了有空闲,但是队列2执行任务较慢;这样队列2中任务可以让线程1帮忙执行(即窃取
线程1的任务)
使用场景
work-stealing算法主要应用于多线程并行计算场景,特别是在任务数量不确定、任务粒度不均匀或者负载容易波动的情况下。以下是一些具体的应用场景:
并行计算框架:work-stealing算法被广泛应用于各种并行计算框架中,如Intel TBB(Threading Building Blocks)、Cilk Plus以及Java的ForkJoinPool等。这些框架利用work-stealing算法来动态地平衡各个线程之间的负载,提高并行计算的效率。
大数据处理:在大数据处理领域,如Hadoop、Spark等分布式计算框架中,work-stealing算法可以用于优化任务调度。通过允许空闲节点窃取其他忙碌节点的任务,可以更加均衡地分配工作负载,提高整个集群的处理能力。
高性能计算:在高性能计算领域,work-stealing算法也被用于优化并行任务的调度。特别是在处理大规模科学计算和模拟仿真等任务时,work-stealing算法能够有效地平衡各个计算节点之间的负载,提高整体的计算效率。
实时系统:在实时系统中,任务的及时完成至关重要。work-stealing算法可以通过动态地调整任务分配,确保各个线程都能够及时完成任务,从而提高系统的实时性能。
云计算和虚拟化环境:在云计算和虚拟化环境中,资源的使用是动态的,并且负载容易波动。work-stealing算法可以用于优化虚拟机或容器之间的任务调度,确保资源的有效利用和负载均衡。
总之,work-stealing算法适用于各种需要高效利用多线程并行计算能力的场景。通过动态地平衡负载和提高资源利用率,它能够显著地提高系统的并行效率和整体性能。
work-stealing例子代码
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
class Task implements Runnable {
private final int taskId;
public Task(int taskId) {
this.taskId = taskId;
}
@Override
public void run() {
try{
TimeUnit.MILLISECONDS.sleep(300);
// 线程0执行的任务慢一点,可以被窃取
if (taskId % 4 == 0) {
TimeUnit.MILLISECONDS.sleep(800);
}
}catch (Exception e){}
System.out.println("Executing task " + taskId + " by " + Thread.currentThread().getName());
}
public int getTaskId() {
return taskId;
}
}
class WorkerThread extends Thread {
private final BlockingQueue<Task> taskQueue;
private final List<WorkerThread> allWorkers;
private final AtomicInteger taskIdGenerator;
private volatile boolean running = true;
public WorkerThread(BlockingQueue<Task> taskQueue, List<WorkerThread> allWorkers, AtomicInteger taskIdGenerator) {
this.taskQueue = taskQueue;
this.allWorkers = allWorkers;
this.taskIdGenerator = taskIdGenerator;
}
@Override
public void run() {
while (running) {
Task task = null;
try {
// Try to retrieve a task from this worker's queue
task = taskQueue.poll(100, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return;
}
if (task != null) {
// Execute the retrieved task
task.run();
} else {
// If no task is retrieved, try to steal a task from another worker
task = stealTask();
if (task != null) {
System.out.println("task " + task.getTaskId() + " stolen by " + Thread.currentThread().getName());
task.run();
}
}
}
}
private Task stealTask() {
List<WorkerThread> shuffledWorkers = new ArrayList<>(allWorkers);
Collections.shuffle(shuffledWorkers);
for (WorkerThread worker : shuffledWorkers) {
if (worker != this && !worker.taskQueue.isEmpty()) {
Task stolenTask = worker.taskQueue.poll();
if (stolenTask != null) {
return stolenTask;
}
}
}
return null;
}
public void addTask(Task task) {
taskQueue.offer(task);
}
public void stopTask() {
running = false;
}
public static void main(String[] args) {
int numWorkers = 4;
BlockingQueue<Task> sharedQueue = new LinkedBlockingQueue<>();
List<WorkerThread> workers = new ArrayList<>();
AtomicInteger taskIdGenerator = new AtomicInteger(0);
// Create and start worker threads
for (int i = 0; i < numWorkers; i++) {
BlockingQueue<Task> taskQueue = new LinkedBlockingQueue<>();
WorkerThread worker = new WorkerThread(taskQueue, workers, taskIdGenerator);
workers.add(worker);
worker.start();
}
// Add tasks to the shared queue (or directly to worker queues for simplicity in this example)
for (int i = 0; i < 20; i++) {
int taskId = taskIdGenerator.incrementAndGet();
int selectWorkId = taskId % numWorkers;
Task task = new Task(taskId);
workers.get(selectWorkId).addTask(task); // Distribute tasks round-robin for simplicity
}
// Let the workers run for some time before stopping them
try {
Thread.sleep(5000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
// Stop all workers
for (WorkerThread worker : workers) {
worker.stopTask();
}
// Wait for all workers to terminate (not strictly necessary in this example, but good practice)
for (WorkerThread worker : workers) {
try {
worker.join();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
}
注意:
- 在这个示例中,我们使用了BlockingQueue来实现任务队列,它支持线程安全的队列操作。
- WorkerThread类表示一个工作线程,它有自己的任务队列,并会尝试执行自己的任务或窃取其他线程的任务。
- Task类表示一个可执行的任务,它简单地打印出正在执行任务的线程名称和任务ID。
- 在main方法中,我们创建了指定数量的工作线程,并向它们分配了任务。然后,让工作线程运行一段时间后停止。
这个示例是一个简化的版本,实际的生产环境中可能需要更复杂的同步机制和错误处理。此外,为了简化示例,任务被直接添加到了工作线程的任务队列中,而不是使用一个共享的窃取队列。在实际实现中,可以考虑使用一个共享的窃取队列来优化任务窃取过程。
测试输出如下:
可以看到线程0的任务被其它线程窃取执行了。
ForkJoinPool
new ForkJoinPool()
public ForkJoinPool() {
this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
defaultForkJoinWorkerThreadFactory, null, false);
}
/**
* Creates a {@code ForkJoinPool} with the given parameters, without
* any security checks or parameter validation. Invoked directly by
* makeCommonPool.
*/
private ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
int mode,
String workerNamePrefix) {
this.workerNamePrefix = workerNamePrefix;
this.factory = factory;
this.ueh = handler;
this.config = (parallelism & SMASK) | mode;
long np = (long)(-parallelism); // offset ctl counts
this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}
默认线程数量是:Runtime.getRuntime().availableProcessors()
ForkJoinTask
public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
Abstract base class for tasks that run within a ForkJoinPool
. A ForkJoinTask
is a thread-like entity that is much lighter weight than a normal thread. Huge numbers of tasks and subtasks may be hosted by a small number of actual threads in a ForkJoinPool, at the price of some usage limitations.
任务提交处理如下:
steal逻辑
例子代码1
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;
public class ForkJoinTaskExample extends RecursiveTask<Integer> {
public static final int threshold = 2;
private int start;
private int end;
public ForkJoinTaskExample(int start, int end) {
this.start = start;
this.end = end;
}
@Override
protected Integer compute() {
int sum = 0;
boolean canCompute = (end - start) <= threshold;
if (canCompute) {
for (int i = start; i <= end; i++) {
sum += i;
}
} else {
// 如果任务大于阈值,就分裂成两个子任务计算
int middle = (start + end) / 2;
ForkJoinTaskExample leftTask = new ForkJoinTaskExample(start, middle);
ForkJoinTaskExample rightTask = new ForkJoinTaskExample(middle + 1, end);
// 执行子任务
leftTask.fork();
rightTask.fork();
// invokeAll(leftTask, rightTask);
// 等待任务执行结束合并其结果
int leftResult = leftTask.join();
int rightResult = rightTask.join();
// 合并子任务
sum = leftResult + rightResult;
}
return sum;
}
static void testForkJoinPool() throws Exception{
ForkJoinPool forkjoinPool = new ForkJoinPool();
int sta = 1;
int end = 100;
//生成一个计算任务,计算连续区间范围的和
ForkJoinTaskExample task = new ForkJoinTaskExample(sta, end);
//执行一个任务
Future<Integer> result = forkjoinPool.submit(task);
System.out.println("result:" + result.get());
}
public static void main(String[] args) throws Exception{
testForkJoinPool();
TimeUnit.SECONDS.sleep(1);
}
}
例子代码2
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
class MyTask implements Runnable {
private final int taskId;
public MyTask(int taskId) {
this.taskId = taskId;
}
@Override
public void run() {
try{
TimeUnit.MILLISECONDS.sleep(300);
// 线程0执行的任务慢一点,可以被窃取
if (taskId % 4 == 0) {
TimeUnit.MILLISECONDS.sleep(800);
}
}catch (Exception e){}
System.out.println("Executing task " + taskId + " by " + Thread.currentThread().getName());
}
public int getTaskId() {
return taskId;
}
}
public class Main {
public static void main(String[] args) throws Exception {
AtomicInteger taskIdGenerator = new AtomicInteger(0);
ForkJoinPool forkjoinPool = new ForkJoinPool();
// Add tasks to the shared queue (or directly to worker queues for simplicity in this example)
for (int i = 0; i < 20; i++) {
int taskId = taskIdGenerator.incrementAndGet();
MyTask task = new MyTask(taskId);
forkjoinPool.submit(task);
}
while (true){}
}
}
可以看到有窃取行为