基于SpringBoot自定义线程池实现多线程执行方法,以及多线程之间的协调和同步
前言
在服务端开发中,多线程开发是非常重要的。因为多线程可以同时处理多个请求,从而提高应用程序的性能,大大改善用户体验。
一、先来了解三个问题
1.在SpringBoot项目中为啥需要自定义线程池?
(1)在SpringBoot项目中,通常会有很多异步的任务需要执行,比如发送邮件、短信、推送等。如果这些任务都直接在主线程中执行,会导致主线程被阻塞,影响用户的体验。因此,通常会使用线程池来管理这些异步任务,从而提高系统的性能和并发能力。
(2)SpringBoot默认提供了一个线程池,但是它的默认配置可能并不适合所有的应用场景。如果应用中的异步任务比较密集,可能会导致线程池中的线程不足,从而影响系统的性能。此时,就需要自行定义线程池,根据应用的实际情况来配置线程池的大小和其他参数,以达到最优的性能表现。
(3)另外,自行定义线程池还可以避免线程池满载时的任务被拒绝执行的问题,从而提高系统的稳定性。
2.java.util.concurrent.CountDownLatch这个类有啥作用?
(1)CountDownLatch 是 Java 中的一个同步工具类,用于协调多个线程之间的执行。它可以让某个线程等待直到倒计时器计数器为 0,然后再继续执行。
(2)CountDownLatch 的作用是,它可以让一个或多个线程等待其他线程执行完毕后再继续执行。在某些场景下,我们需要等待多个线程都执行完毕后才能进行下一步操作,这时候就可以使用 CountDownLatch。
(3)CountDownLatch 的使用方式是,首先创建一个计数器,然后在需要等待的线程中调用计数器的 countDown() 方法,每次调用会将计数器减 1。在需要等待的线程中调用 await() 方法,该方法会一直阻塞直到计数器为 0。当计数器为 0 时,所有等待的线程都会被唤醒,继续执行下一步操作。
(4)例如,我们可以在主线程中创建一个 CountDownLatch,然后将其传递给多个子线程,子线程在执行完任务后调用 countDown() 方法,主线程在需要等待子线程执行完毕后再继续执行时调用 await() 方法,这样就可以实现多个线程之间的协调和同步。
3.同一个类里面,for循环调用异步方法会被串行同步?
(1)在SpringBoot的自定义线程池中,同一个类里面,for循环调用异步方法会被串行同步执行的原因是因为异步方法默认使用的是调用线程的线程池,而在同一个类中,for循环中的所有异步方法都是由同一个调用线程调用的,因此它们会使用同一个线程池,导致它们被串行同步执行。
(2)要解决这个问题,可以在异步方法上添加@Async注解,并在调用异步方法的地方使用代理对象调用。这样每次调用异步方法时,都会使用新的线程池,避免了同一个线程池被多个异步方法共享的问题,从而实现并行执行。另外,为了避免for循环中的异步方法过多导致线程池资源耗尽,可以考虑使用线程池的拒绝策略来处理任务过多的情况。
二、示例代码
1.自定义线程池
(1)CommonThreadPoolConfig.java
package org.example.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadPoolExecutor;
/**
* 公共线程池配置
*/
@EnableAsync
@Configuration
public class CommonThreadPoolConfig {
@Bean("CommonThreadPoolExecutor")
public Executor syncExecutor() {
// 获取可用处理器的Java虚拟机的数量
int sum = Runtime.getRuntime().availableProcessors();
System.out.println("系统最大线程数 -> " + sum);
// 实例化自定义线程池
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
// 设置线程池中的核心线程数(最小线程数)
// 线程池的核心线程数指的是线程池中一直存在的线程数量,即使它们没有任务可执行,处于空闲状态。
// 如果线程池中的线程数小于核心线程数,则创建新线程来处理任务,即使其他空闲线程可用。
// 如果线程池中的线程数已经等于核心线程数,那么新的任务就会被放入任务队列等待执行
executor.setCorePoolSize(16);
// 设置线程池中的最大线程数
// 如果线程池中的线程数已经达到了核心线程数,并且任务队列已满,则创建新线程来处理任务。
// 如果线程池中的线程数等于最大线程数,则任务将被拒绝。
executor.setMaxPoolSize(64);
// 设置线程池中任务队列的容量
// 线程池中的任务队列用于存储还未被执行的任务,当线程池中的线程已经全部被占用时,新的任务会被放入任务队列中等待执行。如果任务队列已满,那么新的任务就会被拒绝执行。
executor.setQueueCapacity(500);
// 设置线程池中空闲线程的存活时间
// 当线程池中的某个线程执行完任务后,如果当前线程池中的线程数大于核心线程数,那么这个空闲线程就会被放入线程池的等待队列中。
// 在等待队列中的空闲线程,如果在`keepAliveSeconds`时间内没有被再次使用,就会被回收销毁,以释放系统资源。如果`keepAliveSeconds`设置为0,则表示空闲线程立即被回收销毁。
executor.setKeepAliveSeconds(60);
// 设置线程池中线程的名称前缀
// 线程池中的每个线程都有一个唯一的名称,这个名称通常是由线程池的名称和线程的编号组成的。使用`setThreadNamePrefix()`方法可以在默认的线程名前面添加一个前缀,以便更好地区分不同的线程池。
executor.setThreadNamePrefix("async-");
// 设置线程池关闭时等待所有任务完成的时间。
// 当调用executor.shutdown()方法关闭线程池时,线程池会等待一段时间,如果在这段时间内所有任务都完成了,线程池会正常关闭;如果还有任务没有完成,线程池将强制关闭,未完成的任务将被丢弃。
executor.setAwaitTerminationSeconds(60);
// 设置线程池中任务队列已满时的拒绝策略,当线程池中的任务队列已满,而且线程池中的线程已经达到了最大线程数时,新的任务就无法被执行。这时就需要设置拒绝策略来处理这种情况。
// setRejectedExecutionHandler()方法提供了几种拒绝策略,包括:
// 1. AbortPolicy:直接抛出RejectedExecutionException异常,阻止系统正常运行。
// 2. CallerRunsPolicy:只要线程池未关闭,该策略直接在调用者线程中,运行当前被丢弃的任务。
// 3. DiscardOldestPolicy:丢弃队列里最老的一个任务,并执行当前任务。
// 4. DiscardPolicy:不处理,直接丢弃掉当前任务。
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.DiscardPolicy());
// 设置线程池在关闭时是否等待所有任务完成
// 如果设置为`true`,则在调用`shutdown()`方法时,线程池会等待所有已提交的任务执行完毕后再关闭。
// 如果设置为`false`,则在调用`shutdown()`方法时,线程池会立即关闭,未执行的任务将被丢弃。
executor.setWaitForTasksToCompleteOnShutdown(true);
// 初始化线程池的配置
executor.initialize();
return executor;
}
}
2.控制层
(1)UserController.java
package org.example.controller;
import org.example.service.impl.UserServiceImpl;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;
@Controller
@RequestMapping(value = "api")
public class UserController {
@Autowired
private UserServiceImpl userService;
/**
* 基于线程池的异步接口
*/
@GetMapping(value = "threadPoolAsyncTest")
@ResponseBody
@CrossOrigin
public <T> T threadPoolAsyncTest () throws InterruptedException {
return userService.threadPoolAsyncTest();
}
/**
* 基于线程池的同步任务
*/
@GetMapping(value = "threadPoolSyncTasks")
@ResponseBody
@CrossOrigin
public <T> T threadPoolSyncTasks () {
return userService.threadPoolSyncTasks();
}
}
3.接口层
(1)IUserService.java
package org.example.service;
public interface IUserService {
<T> T threadPoolAsyncTest();
<T> T threadPoolSyncTasks();
}
(2)IAsyncService.java
package org.example.service;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
public interface IAsyncService {
void sendSms(String mobile, String content);
void sendEmail(String email, String content);
Future<String> sendCode(String mobile) throws InterruptedException;
void syncTasks();
void asyncSaveTask(List<String> blockTaskList, CountDownLatch countDownLatch);
}
4.实现层
(1)UserServiceImpl.java
package org.example.service.impl;
import org.example.service.IAsyncService;
import org.example.service.IUserService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
@Primary
@Service
public class UserServiceImpl implements IUserService {
private static final Logger log = LoggerFactory.getLogger(UserServiceImpl.class);
@Autowired
private IAsyncService asyncService;
@Autowired
private CommonThreadPoolConfig commonThreadPoolConfig;
// commonThreadPoolConfig.destroy(); // 关闭自定义线程池
@Override
public <T> T threadPoolAsyncTest() {
try {
long startTime = System.currentTimeMillis();
String mobile = "13801380000";
String email = "123@abc.com";
String content = "你好,世界!";
asyncService.sendSms(mobile, content);
asyncService.sendEmail(email, content);
Future<String> future = asyncService.sendCode(mobile);
// String result = future.get(); // 阻塞获取结果
String result = "OK";
long endTime = System.currentTimeMillis();
log.info("Main cost {} ms, Future return 【{}】", endTime - startTime, result);
return (T) "success";
} catch (Exception e) {
return (T) "fail";
}
}
@Override
public <T> T threadPoolSyncTasks() {
long startTime = System.currentTimeMillis();
HashMap<String, Object> responseObj = new HashMap<>();
asyncService.syncTasks();
responseObj.put("code", 200);
responseObj.put("success", true);
responseObj.put("msg", "开始同步任务");
long endTime = System.currentTimeMillis();
log.info("threadPoolSyncTasks -> 线程名:{},运行时长:{} ms", Thread.currentThread().getName(), endTime - startTime); // 3 ms
return (T) responseObj;
}
/**
* 同步任务
*/
public void syncTasks() {
// 构建一个有100000个任务的列表 [Task-1 ~ Task-10000]
List<String> taskList = new ArrayList<>();
for (int i = 0; i < 10001; i++) {
taskList.add("Task-" + (i + 1));
}
// 每个区块可容纳1000条任务
int blockSize = 1000;
// 区块数量 10
int blockSum = taskList.size() % blockSize == 0 ? taskList.size() / blockSize : taskList.size() / blockSize + 1;
// 以区块分段的任务列表
List<List<String>> targetTaskList = new ArrayList<>();
for (int i = 0; i < blockSum - 1; i++) { // 10 - 1 = 9
// 0 [Task-1 ~ Task-1000]
// ...
// 8 [Task-8001 ~ Task-9000]
targetTaskList.add(i, taskList.subList(i * blockSize, blockSize * (i + 1)));
}
// 9 [Task-9001 ~ Task-10000]
targetTaskList.add(blockSum - 1, taskList.subList((blockSum - 1) * blockSize, taskList.size()));
System.out.println(targetTaskList);
this.batchSaveTask(targetTaskList);
}
/**
* 批量同步任务
*/
public void batchSaveTask(List<List<String>> targetTaskList) {
// 主线程中创建一个CountDownLatch计数器,数值为9,然后将其传递给多个子线程
CountDownLatch countDownLatch = new CountDownLatch(targetTaskList.size());
try {
long startTime = System.currentTimeMillis();
for (int i = 0; i < targetTaskList.size(); i++) {
List<String> blockTaskList = targetTaskList.get(i);
asyncService.asyncSaveTask(blockTaskList, countDownLatch);
}
// 主线程在需要等待子线程执行完毕后再继续执行时调用 await() 方法
countDownLatch.await();
long endTime = System.currentTimeMillis();
log.info("batchSaveTask -> 线程名:{},所有任务都执行完毕,运行时长:{} ms", Thread.currentThread().getName(), endTime - startTime); // 1016 ms
} catch (Exception e) {
e.printStackTrace();
}
}
}
(2)AsyncServiceImpl.java
package org.example.service.impl;
import org.example.service.IAsyncService;
import org.example.service.IUserService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Primary;
import org.springframework.scheduling.annotation.Async;
import org.springframework.scheduling.annotation.AsyncResult;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
@Primary
@Service
public class AsyncServiceImpl implements IAsyncService {
private static final Logger log = LoggerFactory.getLogger(AsyncServiceImpl.class);
@Autowired
private IUserService userService;
/**
* 发送短信
*/
@Override
@Async(value = "CommonThreadPoolExecutor")
public void sendSms(String mobile, String content) {
try {
Thread.sleep(5000);
// xxxHandle.sendSms(mobile, content)
log.info("发送短信至 {} 成功,短信内容:{}", mobile, content);
} catch (Exception e) {
log.error("发送短信至 {} 失败,异常信息:{}", mobile, e);
}
}
/**
* 发送邮件
*/
@Override
@Async(value = "CommonThreadPoolExecutor")
public void sendEmail(String email, String content) {
try {
Thread.sleep(5000);
// xxxHandle.sendEmail(email, content)
log.info("发送邮件至 {} 成功,邮件内容:{}", email, content);
} catch (Exception e) {
log.error("发送邮件至 {} 失败,异常信息:{}", email, e);
}
}
/**
* 发送验证码
*/
@Override
@Async(value = "CommonThreadPoolExecutor")
public Future<String> sendCode(String mobile) throws InterruptedException {
Thread.sleep(3000);
log.info("尊敬的开发者,Thread: [{}], 为您服务...", Thread.currentThread().getName());
return new AsyncResult<>("发送验证码至 " + mobile + " 成功");
}
@Async("CommonThreadPoolExecutor")
public void syncTasks() {
userService.syncTasks();
}
@Async("CommonThreadPoolExecutor")
public void asyncSaveTask(List<String> tasks, CountDownLatch countDownLatch) {
try {
Thread.sleep(1000);
log.info("asyncSaveTask -> 线程名:{},保存数量为{}的任务成功", Thread.currentThread().getName(), tasks.size());
} catch (Exception e) {
e.printStackTrace();
} finally {
// 子线程在执行完任务后调用countDown()方法,将计数器减1
countDownLatch.countDown();
}
}
}