当前位置: 首页 > article >正文

work-stealing算法 ForkJoinPool

专栏系列文章地址:https://blog.csdn.net/qq_26437925/article/details/145290162

本文目标:

  1. 重点是通过例子程序理解work-stealing算法原理

目录

  • work-stealing算法
    • 算法原理和优缺点介绍
    • 使用场景
    • work-stealing例子代码
  • ForkJoinPool
    • new ForkJoinPool()
    • ForkJoinTask
    • 例子代码1
    • 例子代码2

work-stealing算法

算法原理和优缺点介绍

work-stealing算法是一种用于多线程并行计算中的任务调度算法。该算法的核心思想是允许空闲的线程从其他忙碌线程的工作队列中“窃取”任务来执行,以提高整体的资源利用率和系统的吞吐量。

在work-stealing算法中,每个线程通常都维护一个自己的任务队列,用于存储需要执行的任务。当某个线程完成自己的任务队列中的所有任务后,它会尝试从其他线程的任务队列中窃取任务。为了防止多个线程同时窃取同一个线程的任务,通常需要使用一些同步机制,如锁或原子操作等。

work-stealing算法的优点在于它能够动态地平衡负载,使得各个线程之间的任务分配更加均匀,从而提高了系统的并行效率和资源利用率。此外,该算法还具有较好的可扩展性和适应性,能够随着任务量的增加或减少而自动调整线程的数量和工作负载。

然而,work-stealing算法也存在一些挑战和限制。例如,在窃取任务时需要进行同步操作,这可能会增加一定的开销。此外,如果任务之间存在数据依赖关系,那么窃取任务可能会破坏这种依赖关系,从而导致错误的结果。因此,在使用work-stealing算法时需要根据具体的应用场景和任务特点进行权衡和选择。


介绍部分来自ai回答。可以简单理解如下:

一个大任务分割为若干个互不依赖的子任务,为了减少线程间的竞争,把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应。

线程1 - 队列1(任务1,任务2,任务3,…)
线程2 - 队列2(任务1,任务2,任务3,…)

比如线程1早早的把队列中任务都处理完了有空闲,但是队列2执行任务较慢;这样队列2中任务可以让线程1帮忙执行(即窃取线程1的任务)

使用场景

work-stealing算法主要应用于多线程并行计算场景,特别是在任务数量不确定、任务粒度不均匀或者负载容易波动的情况下。以下是一些具体的应用场景:

‌并行计算框架‌:work-stealing算法被广泛应用于各种并行计算框架中,如Intel TBB(Threading Building Blocks)、Cilk Plus以及Java的ForkJoinPool等。这些框架利用work-stealing算法来动态地平衡各个线程之间的负载,提高并行计算的效率。

‌大数据处理‌:在大数据处理领域,如Hadoop、Spark等分布式计算框架中,work-stealing算法可以用于优化任务调度。通过允许空闲节点窃取其他忙碌节点的任务,可以更加均衡地分配工作负载,提高整个集群的处理能力。

‌高性能计算‌:在高性能计算领域,work-stealing算法也被用于优化并行任务的调度。特别是在处理大规模科学计算和模拟仿真等任务时,work-stealing算法能够有效地平衡各个计算节点之间的负载,提高整体的计算效率。

‌实时系统‌:在实时系统中,任务的及时完成至关重要。work-stealing算法可以通过动态地调整任务分配,确保各个线程都能够及时完成任务,从而提高系统的实时性能。

‌云计算和虚拟化环境‌:在云计算和虚拟化环境中,资源的使用是动态的,并且负载容易波动。work-stealing算法可以用于优化虚拟机或容器之间的任务调度,确保资源的有效利用和负载均衡。

总之,work-stealing算法适用于各种需要高效利用多线程并行计算能力的场景。通过动态地平衡负载和提高资源利用率,它能够显著地提高系统的并行效率和整体性能。

work-stealing例子代码

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class Task implements Runnable {
    private final int taskId;

    public Task(int taskId) {
        this.taskId = taskId;
    }

    @Override
    public void run() {
        try{
            TimeUnit.MILLISECONDS.sleep(300);
            // 线程0执行的任务慢一点,可以被窃取
            if (taskId % 4 == 0) {
                TimeUnit.MILLISECONDS.sleep(800);
            }
        }catch (Exception e){}
        System.out.println("Executing task " + taskId + " by " + Thread.currentThread().getName());
    }

    public int getTaskId() {
        return taskId;
    }
}

class WorkerThread extends Thread {
    private final BlockingQueue<Task> taskQueue;
    private final List<WorkerThread> allWorkers;
    private final AtomicInteger taskIdGenerator;
    private volatile boolean running = true;

    public WorkerThread(BlockingQueue<Task> taskQueue, List<WorkerThread> allWorkers, AtomicInteger taskIdGenerator) {
        this.taskQueue = taskQueue;
        this.allWorkers = allWorkers;
        this.taskIdGenerator = taskIdGenerator;
    }

    @Override
    public void run() {
        while (running) {
            Task task = null;
            try {
                // Try to retrieve a task from this worker's queue
                task = taskQueue.poll(100, TimeUnit.MILLISECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }

            if (task != null) {
                // Execute the retrieved task
                task.run();
            } else {
                // If no task is retrieved, try to steal a task from another worker
                task = stealTask();
                if (task != null) {
                    System.out.println("task " + task.getTaskId() + " stolen by " + Thread.currentThread().getName());
                    task.run();
                }
            }
        }
    }

    private Task stealTask() {
        List<WorkerThread> shuffledWorkers = new ArrayList<>(allWorkers);
        Collections.shuffle(shuffledWorkers);

        for (WorkerThread worker : shuffledWorkers) {
            if (worker != this && !worker.taskQueue.isEmpty()) {
                Task stolenTask = worker.taskQueue.poll();
                if (stolenTask != null) {
                    return stolenTask;
                }
            }
        }

        return null;
    }

    public void addTask(Task task) {
        taskQueue.offer(task);
    }

    public void stopTask() {
        running = false;
    }


    public static void main(String[] args) {
        int numWorkers = 4;
        BlockingQueue<Task> sharedQueue = new LinkedBlockingQueue<>();
        List<WorkerThread> workers = new ArrayList<>();
        AtomicInteger taskIdGenerator = new AtomicInteger(0);

        // Create and start worker threads
        for (int i = 0; i < numWorkers; i++) {
            BlockingQueue<Task> taskQueue = new LinkedBlockingQueue<>();
            WorkerThread worker = new WorkerThread(taskQueue, workers, taskIdGenerator);
            workers.add(worker);
            worker.start();
        }

        // Add tasks to the shared queue (or directly to worker queues for simplicity in this example)
        for (int i = 0; i < 20; i++) {
            int taskId = taskIdGenerator.incrementAndGet();
            int selectWorkId = taskId % numWorkers;
            Task task = new Task(taskId);
            workers.get(selectWorkId).addTask(task); // Distribute tasks round-robin for simplicity
        }

        // Let the workers run for some time before stopping them
        try {
            Thread.sleep(5000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }

        // Stop all workers
        for (WorkerThread worker : workers) {
            worker.stopTask();
        }

        // Wait for all workers to terminate (not strictly necessary in this example, but good practice)
        for (WorkerThread worker : workers) {
            try {
                worker.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }
}

‌注意‌:

  1. 在这个示例中,我们使用了BlockingQueue来实现任务队列,它支持线程安全的队列操作。
  2. WorkerThread类表示一个工作线程,它有自己的任务队列,并会尝试执行自己的任务或窃取其他线程的任务。
  3. Task类表示一个可执行的任务,它简单地打印出正在执行任务的线程名称和任务ID。
  4. 在main方法中,我们创建了指定数量的工作线程,并向它们分配了任务。然后,让工作线程运行一段时间后停止。

这个示例是一个简化的版本,实际的生产环境中可能需要更复杂的同步机制和错误处理。此外,为了简化示例,任务被直接添加到了工作线程的任务队列中,而不是使用一个共享的窃取队列。在实际实现中,可以考虑使用一个共享的窃取队列来优化任务窃取过程。

测试输出如下:
在这里插入图片描述

可以看到线程0的任务被其它线程窃取执行了。

ForkJoinPool

在这里插入图片描述

new ForkJoinPool()

public ForkJoinPool() {
        this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
             defaultForkJoinWorkerThreadFactory, null, false);
    }
/**
    * Creates a {@code ForkJoinPool} with the given parameters, without
    * any security checks or parameter validation.  Invoked directly by
    * makeCommonPool.
    */
private ForkJoinPool(int parallelism,
                        ForkJoinWorkerThreadFactory factory,
                        UncaughtExceptionHandler handler,
                        int mode,
                        String workerNamePrefix) {
    this.workerNamePrefix = workerNamePrefix;
    this.factory = factory;
    this.ueh = handler;
    this.config = (parallelism & SMASK) | mode;
    long np = (long)(-parallelism); // offset ctl counts
    this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}

默认线程数量是:Runtime.getRuntime().availableProcessors()

ForkJoinTask

public abstract class ForkJoinTask<V> implements Future<V>, Serializable {

Abstract base class for tasks that run within a ForkJoinPool. A ForkJoinTask is a thread-like entity that is much lighter weight than a normal thread. Huge numbers of tasks and subtasks may be hosted by a small number of actual threads in a ForkJoinPool, at the price of some usage limitations.

任务提交处理如下:
在这里插入图片描述

steal逻辑
在这里插入图片描述

例子代码1

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

public class ForkJoinTaskExample extends RecursiveTask<Integer> {

    public static final int threshold = 2;
    private int start;
    private int end;

    public ForkJoinTaskExample(int start, int end) {
        this.start = start;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        int sum = 0;

        boolean canCompute = (end - start) <= threshold;
        if (canCompute) {
            for (int i = start; i <= end; i++) {
                sum += i;
            }
        } else {
            // 如果任务大于阈值,就分裂成两个子任务计算
            int middle = (start + end) / 2;
            ForkJoinTaskExample leftTask = new ForkJoinTaskExample(start, middle);
            ForkJoinTaskExample rightTask = new ForkJoinTaskExample(middle + 1, end);

            // 执行子任务
            leftTask.fork();
            rightTask.fork();
            // invokeAll(leftTask, rightTask);

            // 等待任务执行结束合并其结果
            int leftResult = leftTask.join();
            int rightResult = rightTask.join();

            // 合并子任务
            sum = leftResult + rightResult;
        }
        return sum;
    }

    static void testForkJoinPool() throws Exception{
        ForkJoinPool forkjoinPool = new ForkJoinPool();
        int sta = 1;
        int end = 100;
        //生成一个计算任务,计算连续区间范围的和
        ForkJoinTaskExample task = new ForkJoinTaskExample(sta, end);
        //执行一个任务
        Future<Integer> result = forkjoinPool.submit(task);
        System.out.println("result:" + result.get());
    }

    public static void main(String[] args) throws Exception{
        testForkJoinPool();

        TimeUnit.SECONDS.sleep(1);
    }
}

例子代码2

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;


class MyTask implements Runnable {
    private final int taskId;

    public MyTask(int taskId) {
        this.taskId = taskId;
    }

    @Override
    public void run() {
        try{
            TimeUnit.MILLISECONDS.sleep(300);
            // 线程0执行的任务慢一点,可以被窃取
            if (taskId % 4 == 0) {
                TimeUnit.MILLISECONDS.sleep(800);
            }
        }catch (Exception e){}
        System.out.println("Executing task " + taskId + " by " + Thread.currentThread().getName());
    }

    public int getTaskId() {
        return taskId;
    }
}

public class Main {

    public static void main(String[] args) throws Exception {
        AtomicInteger taskIdGenerator = new AtomicInteger(0);

        ForkJoinPool forkjoinPool = new ForkJoinPool();

        // Add tasks to the shared queue (or directly to worker queues for simplicity in this example)
        for (int i = 0; i < 20; i++) {
            int taskId = taskIdGenerator.incrementAndGet();
            MyTask task = new MyTask(taskId);
            forkjoinPool.submit(task);
        }

        while (true){}
    }

}

可以看到有窃取行为
在这里插入图片描述


http://www.kler.cn/a/530149.html

相关文章:

  • VSCode插件Live Server
  • SpringCloud篇 微服务架构
  • 体系自适应的物联网漏洞挖掘系统摘要
  • HTB:Alert[WriteUP]
  • Kafka常见问题之 java.io.IOException: Disk error when trying to write to log
  • MySQL 如何深度分页问题
  • 【C语言】填空题/程序填空题1
  • 第三百六十节 JavaFX教程 - JavaFX 进度显示器
  • 2025-工具集合整理
  • 2025年2月2日(网络编程 tcp)
  • LeetCode:300.最长递增子序列
  • 【Rust自学】19.3. 高级函数和闭包
  • 【TCP协议】流量控制 滑动窗口 拥塞控制
  • 第二篇:多模态技术突破——DeepSeek如何重构AI的感知与认知边界
  • Spring应用场景 特性
  • 【C语言】自定义类型讲解
  • mysql字段名批量大小写转换
  • HarmonyOS NEXT:保存应用数据
  • 消息队列应用示例MessageQueues-STM32CubeMX-FreeRTOS《嵌入式系统设计》P343-P347
  • vector容器(详解)
  • 【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.3 结构化索引:记录数组与字段访问
  • ExternalName Service 针对的是k8s集群外部有api服务的场景?
  • Haskell语言的多线程编程
  • [权限提升] Windows 提权 维持 — 系统错误配置提权 - Trusted Service Paths 提权
  • IM 即时通讯系统-43-简单的仿QQ聊天安卓APP
  • 2024 年 6 月大学英语四级考试真题(第 3 套)——纯享题目版