【防止重复提交】Redis + AOP + 注解的方式实现分布式锁
文章目录
- 工作原理
- 需求实现
- 1)自定义防重复提交注解
- 2)定义防重复提交AOP切面
- 3)RedisLock 工具类
- 4)过滤器 + 请求工具类
- 5)测试Controller
- 6)测试结果
工作原理
分布式环境下,可能会出现用户重复点击、导致某个接口被重复调用的场景。为了防止接口重复提交造成的问题,可以用 Redis 实现一个简单的分布式锁来解决。
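在 Redis 中,SETNX 命令可以帮助我们实现互斥。SETNX 即 SET if Not eXists(对应 Java 中的 setIfAbsent 方法):只有当 key 不存在时才会设置 key 的值;如果 key 已经存在,SETNX 什么也不做。本文使用带过期时间的 setIfAbsent(key, value, timeout, unit),相当于 SET key value NX EX seconds,可以原子地完成"加锁 + 设置过期时间",避免持锁方异常退出后锁永远无法释放。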
需求实现
- 自定义一个防止重复提交的注解,注解中可以指定锁的过期时间,以及一个参与拼接 key 的请求参数名
- 为需要防止重复提交的接口添加该注解
- AOP 切面会拦截加了此注解的请求,进行加锁和解锁处理,并以注解上设置的时间作为 key 的过期时间
- Redis 中的 key = token + "-" + path + "-" + param_value;(例如:17800000001 + /api/subscribe/ + zhangsan)
- 如果重复调用某个加了注解的接口且 key 还未释放(前一次请求仍在执行,且未超过过期时间),就会返回重复提交的 Result。
1)自定义防重复提交注解
自定义防止重复提交的注解,注解中可设置超时时间和要扫描的参数(即请求中的某个参数名,其值最终会拼接进 Redis 中的 key)。
package com.lihw.lihwtestboot.noRepeatSubmit;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a method whose repeated submissions should be rejected.
 *
 * <p>The annotation is retained at runtime and may only be placed on methods, so an aspect can
 * read {@link #seconds()} and {@link #scanParam()} from the intercepted method.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface NoRepeatSubmit {

    /**
     * Name of the request parameter whose value is appended to the lock key.
     * An empty string (the default) means no parameter value is used.
     */
    String scanParam() default "";

    /**
     * Lifetime of the lock, in seconds.
     */
    int seconds() default 5;
}
2)定义防重复提交AOP切面
@Pointcut("@annotation(noRepeatSubmit)") 是切点表达式,它通过注解匹配的方式选中被 @NoRepeatSubmit 标记的方法,并把该注解实例绑定到通知方法的 noRepeatSubmit 参数上,以便读取 seconds 和 scanParam。
package com.lihw.lihwtestboot.noRepeatSubmit;
import com.alibaba.fastjson.JSONObject;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.UUID;
/**
 * Aspect that rejects repeated submissions of methods annotated with {@link NoRepeatSubmit}.
 *
 * <p>For every intercepted call a lock key is built from the {@code uid} request header, the
 * servlet path and, when {@link NoRepeatSubmit#scanParam()} is non-empty, that parameter's value.
 * The key is acquired through {@link RedisLock#tryLock}. If it is already held, the target method
 * is not invoked and an {@link ApiResult} whose data is {@code "重复提交"} ("duplicate submission")
 * is returned instead.
 *
 * <p>The lock is released as soon as the target method finishes. {@link NoRepeatSubmit#seconds()}
 * therefore only bounds how long a lock can live (e.g. if it is never released); it is not a
 * cool-down period after a completed call.
 *
 * <p>NOTE(review): requests without a {@code uid} header all get the key prefix {@code "null"},
 * so anonymous callers of the same path block each other — confirm this is intended.
 * NOTE(review): on a duplicate an {@code ApiResult} is returned regardless of the annotated
 * method's declared return type; presumably annotated methods must return {@code ApiResult}.
 */
@Aspect
@Component
public class RepeatSubmitAspect {
    private static final Logger LOGGER = LoggerFactory.getLogger(RepeatSubmitAspect.class);

    @Autowired
    private RedisLock redisLock;

    /** Matches methods annotated with {@link NoRepeatSubmit} and binds the annotation instance. */
    @Pointcut("@annotation(noRepeatSubmit)")
    public void pointCut(NoRepeatSubmit noRepeatSubmit) {
    }

    /**
     * Acquires the lock for the current request, invokes the target method and releases the lock.
     *
     * @param pjp            the intercepted invocation
     * @param noRepeatSubmit the annotation on the intercepted method
     * @return the target method's result, or an {@link ApiResult} marking a duplicate submission
     * @throws Throwable whatever the target method throws
     */
    @Around("pointCut(noRepeatSubmit)")
    public Object around(ProceedingJoinPoint pjp, NoRepeatSubmit noRepeatSubmit) throws Throwable {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        // Outside an HTTP request there are no attributes; fail with a clear message, not an NPE.
        Assert.notNull(attributes, "request can not null");
        HttpServletRequest request = attributes.getRequest();
        Assert.notNull(request, "request can not null");

        String path = request.getServletPath();
        String token = request.getHeader("uid");
        LOGGER.info("线程:{}, 接口:{},重复提交验证", Thread.currentThread().getName(), path);

        String key = buildKey(request, token, path, noRepeatSubmit.scanParam());
        // Unique per invocation, so only this call can release the lock it acquired.
        String clientId = getClientId();
        if (!redisLock.tryLock(key, clientId, noRepeatSubmit.seconds())) {
            // Lock already held: treat as a repeated submission.
            LOGGER.info("重复请求:接口 = {}, key = {}", path, key);
            ApiResult result = new ApiResult();
            result.setData("重复提交");
            return result;
        }

        LOGGER.info("加锁成功:接口 = {}, key = {}", path, key);
        try {
            return pjp.proceed();
        } finally {
            releaseQuietly(key, clientId, path);
        }
    }

    /**
     * Builds the lock key: {@code token-path}, or {@code token-path-paramValue} when a scan
     * parameter is configured. The body is only inspected when a parameter is configured.
     */
    private String buildKey(HttpServletRequest request, String token, String path, String param) throws IOException {
        if (param == null || param.isEmpty()) {
            return token + "-" + path;
        }
        return token + "-" + path + "-" + resolveParamValue(request, param);
    }

    /**
     * Looks up the scan parameter's value.
     * <ul>
     *   <li>POST: the top-level field of the JSON body, falling back to the query string.</li>
     *   <li>GET: the query string.</li>
     *   <li>Any other method: {@code ""}.</li>
     * </ul>
     * May return {@code null} when the value is absent.
     */
    private String resolveParamValue(HttpServletRequest request, String param) throws IOException {
        String method = request.getMethod();
        if ("POST".equals(method)) {
            String fromBody = readJsonBodyField(request, param);
            return fromBody != null ? fromBody : request.getParameter(param);
        }
        if ("GET".equals(method)) {
            return request.getParameter(param);
        }
        return "";
    }

    /**
     * Returns the top-level field {@code param} of the request body parsed as a JSON object.
     * Returns {@code null} if the body is blank, is not a JSON object, or lacks the field.
     */
    private String readJsonBodyField(HttpServletRequest request, String param) throws IOException {
        String body = new BodyReaderHttpServletRequestWrapper(request).getBodyString();
        if (body.trim().isEmpty()) {
            // parseObject would yield null here; nothing to extract.
            return null;
        }
        try {
            JSONObject json = JSONObject.parseObject(body);
            return json == null ? null : json.getString(param);
        } catch (RuntimeException e) {
            // fastjson reports malformed or non-object JSON (e.g. arrays, form bodies) with
            // unchecked exceptions; don't turn that into a server error for the caller.
            LOGGER.warn("请求体不是JSON对象, 无法读取参数: {}", param);
            return null;
        }
    }

    /**
     * Releases the lock if this invocation still owns it.
     * Redis failures are logged rather than rethrown, so they never mask the target method's
     * result or exception; the key still expires after its timeout.
     */
    private void releaseQuietly(String key, String clientId, String path) {
        try {
            if (redisLock.releaseLock(key, clientId)) {
                LOGGER.info("解锁成功:接口={}, key = {},", path, key);
            }
        } catch (RuntimeException e) {
            LOGGER.error("解锁异常:接口={}, key = {}", path, key, e);
        }
    }

    /** Generates a random, per-invocation lock owner id. */
    private String getClientId() {
        return UUID.randomUUID().toString();
    }

    /**
     * Reads the whole request body via {@link HttpServletRequest#getReader()}.
     * Line separators are not preserved (lines are concatenated directly).
     *
     * @param request the request whose body to read
     * @return the body with line breaks removed
     * @throws IOException if reading fails
     */
    public static String getRequestBodyData(HttpServletRequest request) throws IOException {
        BufferedReader bufferReader = new BufferedReader(request.getReader());
        StringBuilder sb = new StringBuilder();
        String line;
        while ((line = bufferReader.readLine()) != null) {
            sb.append(line);
        }
        return sb.toString();
    }
}
3)RedisLock 工具类
package com.lihw.lihwtestboot.noRepeatSubmit;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Service;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
/**
 * Small Redis-backed mutex built on {@link StringRedisTemplate}.
 *
 * <p>A lock is a string key whose value is the owner's client id. It is acquired with
 * {@code SET key value NX EX seconds} (via {@code setIfAbsent} with a timeout). It is released
 * with a Lua compare-and-delete script, so the ownership check and the delete run atomically on
 * the Redis server. A client therefore cannot delete a lock that expired and was re-acquired by
 * someone else in between.
 */
@Service
public class RedisLock {
    private static final Logger logger = LoggerFactory.getLogger(RedisLock.class);

    /** Sentinel for {@link #get(String, long)}: leave the key's expiry untouched. */
    public static final long NOT_EXPIRE = -1;

    /** Deletes KEYS[1] only if it currently holds ARGV[1]; returns the number of keys deleted. */
    private static final RedisScript<Long> RELEASE_SCRIPT = RedisScript.of(
            "if redis.call('get', KEYS[1]) == ARGV[1] then "
                    + "return redis.call('del', KEYS[1]) "
                    + "else return 0 end",
            Long.class);

    @Autowired
    private StringRedisTemplate redisTemplate;

    /**
     * Tries to acquire the lock without waiting.
     *
     * @param lockKey  key to lock
     * @param clientId unique id of the caller (e.g. a UUID); only this id can release the lock
     * @param seconds  lock lifetime in seconds; must be positive
     * @return {@code true} if the lock was acquired, {@code false} if the key already exists
     * @throws IllegalArgumentException if {@code seconds <= 0}
     */
    public boolean tryLock(String lockKey, String clientId, long seconds) {
        if (seconds <= 0) {
            // Redis rejects a non-positive expire time; fail fast with a clear message instead.
            throw new IllegalArgumentException("lock seconds must be positive, got " + seconds);
        }
        // setIfAbsent returns a boxed Boolean that can be null (e.g. inside a pipeline or
        // transaction); comparing against TRUE avoids an unboxing NullPointerException.
        Boolean acquired = redisTemplate.opsForValue().setIfAbsent(lockKey, clientId, seconds, TimeUnit.SECONDS);
        return Boolean.TRUE.equals(acquired);
    }

    /**
     * Releases a lock acquired with {@link #tryLock}, but only if it is still held by {@code clientId}.
     *
     * @param lockKey  key that was locked
     * @param clientId id passed to {@link #tryLock}
     * @return {@code true} if the key was deleted. {@code false} if {@code clientId} is empty,
     *         the key is gone or owned by another client, or Redis reported an error (which is logged).
     */
    public boolean releaseLock(String lockKey, String clientId) {
        if (StringUtils.isEmpty(clientId)) {
            return false;
        }
        try {
            Long deleted = redisTemplate.execute(RELEASE_SCRIPT, Collections.singletonList(lockKey), clientId);
            return deleted != null && deleted > 0;
        } catch (RuntimeException e) {
            // The exception goes last, without a placeholder, so SLF4J logs the full stack trace.
            logger.error("解锁异常, key = {}", lockKey, e);
            return false;
        }
    }

    /**
     * Returns the string value of {@code key} without changing its expiry.
     *
     * @param key key to read
     * @return the value, or {@code null} if the key does not exist
     */
    public String get(String key) {
        return get(key, NOT_EXPIRE);
    }

    /**
     * Returns the string value of {@code key}, optionally resetting its expiry.
     * The GET and EXPIRE are two separate commands (not atomic).
     *
     * @param key    key to read
     * @param expire new expiry in seconds, or {@link #NOT_EXPIRE} to leave it unchanged
     * @return the value, or {@code null} if the key does not exist
     */
    public String get(String key, long expire) {
        String value = redisTemplate.opsForValue().get(key);
        if (expire != NOT_EXPIRE) {
            redisTemplate.expire(key, expire, TimeUnit.SECONDS);
        }
        return value;
    }

    /**
     * Deletes {@code key} unconditionally.
     *
     * @param key key to delete
     */
    public void delete(String key) {
        redisTemplate.delete(key);
    }
}
4)过滤器 + 请求工具类
Filter类
package com.lihw.lihwtestboot.noRepeatSubmit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.web.servlet.ServletComponentScan;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
/**
 * Servlet filter that wraps HTTP requests in a {@link BodyReaderHttpServletRequestWrapper}.
 * The request body can then be read more than once downstream, e.g. once to extract the lock
 * key parameter and again for controller body binding.
 *
 * <p>NOTE(review): {@code @ServletComponentScan} appears to take effect only on Spring
 * configuration classes. On this filter, which is not itself a Spring bean, it likely does
 * nothing, so {@code @WebFilter} is only registered if the application class carries
 * {@code @ServletComponentScan}. Confirm the filter is actually active.
 */
@ServletComponentScan
@WebFilter(urlPatterns = "/*", filterName = "channelFilter")
public class ChannelFilter implements Filter {
    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // Nothing to initialize.
    }

    /**
     * Buffers the request body (for HTTP requests) and continues the chain with the wrapper.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        logger.info("-----------------------Execute filter start---------------------");
        ServletRequest request = servletRequest;
        // An input stream can be consumed only once, so cache the body in a wrapper. Only HTTP
        // requests can be wrapped, and re-wrapping an already buffered request is wasted work.
        if (servletRequest instanceof HttpServletRequest
                && !(servletRequest instanceof BodyReaderHttpServletRequestWrapper)) {
            request = new BodyReaderHttpServletRequestWrapper((HttpServletRequest) servletRequest);
        }
        filterChain.doFilter(request, servletResponse);
    }
}
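BodyReaderHttpServletRequestWrapper
对请求体进行缓存:请求的输入流默认只能读取一次,该包装类在构造时把 body 读入内存,之后每次调用 getInputStream()/getReader() 都从缓存中读取,因此切面和 Controller 都能读取到 POST 请求的 body。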
package com.lihw.lihwtestboot.noRepeatSubmit;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
/**
 * Request wrapper that buffers the request body in memory so it can be read more than once.
 *
 * <p>A servlet request's input stream can normally be consumed only once. This wrapper reads it
 * fully in the constructor. Every later {@link #getInputStream()} / {@link #getReader()} call is
 * served from a fresh stream over the buffered bytes.
 *
 * <p>NOTE(review): the entire body is held in memory, including large uploads.
 */
public class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Raw, unmodified request body bytes. */
    private final byte[] body;

    /**
     * Wraps {@code request} and consumes its input stream into memory.
     *
     * @param request the request to wrap
     * @throws IOException if the body cannot be read
     */
    public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        // Keep the exact bytes (line breaks, non-UTF-8 or binary content). Decoding line by line
        // and re-encoding would strip newlines and corrupt such bodies.
        body = readAllBytes(request.getInputStream());
    }

    /**
     * Returns the buffered body decoded as UTF-8.
     *
     * @return the body text
     */
    public String getBodyString() {
        return new String(body, StandardCharsets.UTF_8);
    }

    /**
     * Copies the remaining content of {@code inputStream} into an in-memory stream.
     *
     * @param inputStream the stream to drain
     * @return an in-memory stream over the copied bytes
     * @throws UncheckedIOException if reading fails, instead of silently returning a truncated copy
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        try {
            return new ByteArrayInputStream(readAllBytes(inputStream));
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to copy request input stream", e);
        }
    }

    /**
     * Returns a reader over the buffered body, decoded with the request's character encoding
     * (UTF-8 if none is declared or it is unsupported).
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), resolveCharset()));
    }

    /** Returns a new stream over the buffered body on every call. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream bais = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public int read() {
                return bais.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return bais.read(b, off, len);
            }

            @Override
            public int available() {
                return bais.available();
            }

            @Override
            public boolean isFinished() {
                return bais.available() == 0;
            }

            @Override
            public boolean isReady() {
                // All data is already in memory, so reads never block.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reads are not supported by this wrapper.
            }
        };
    }

    /** Charset declared by the request, or UTF-8 when absent or unsupported. */
    private Charset resolveCharset() {
        String encoding = getCharacterEncoding();
        if (encoding != null) {
            try {
                return Charset.forName(encoding);
            } catch (IllegalArgumentException ignored) {
                // Illegal or unsupported charset name: fall back to UTF-8 below.
            }
        }
        return StandardCharsets.UTF_8;
    }

    /** Reads {@code in} to end of stream and returns all bytes read. */
    private static byte[] readAllBytes(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[4096];
        int len;
        while ((len = in.read(buffer)) != -1) {
            out.write(buffer, 0, len);
        }
        return out.toByteArray();
    }
}
5)测试Controller
package com.lihw.lihwtestboot.noRepeatSubmit;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import javax.validation.constraints.NotEmpty;
/**
 * Demo controller showing {@link NoRepeatSubmit} on a GET endpoint.
 */
@RestController
@RequestMapping("/api")
@Validated
public class noRepeatSubmitController {

    /**
     * Simulates a slow subscription so a second click arrives while the lock is still held.
     *
     * <p>NOTE(review): for GET requests the repeat-submit aspect reads {@code scanParam} via
     * {@code request.getParameter}, i.e. from the query string. This endpoint receives
     * {@code username} as a header, so unless {@code ?username=...} is also sent, the lock key
     * ends with {@code "null"}.
     *
     * @param phone    value of the {@code uid} header
     * @param username value of the {@code username} header
     * @param channel  path variable, must not be empty
     * @return a fixed success result
     */
    @GetMapping("/subscribe/{channel}")
    @NoRepeatSubmit(seconds = 10, scanParam = "username")
    public ApiResult subscribe(@RequestHeader(name = "uid") String phone, @RequestHeader(name = "username") String username, @PathVariable("channel") @NotEmpty(message = "channel不能为空") String channel) {
        System.out.println("phone=" + phone);
        System.out.println("username=" + username);
        System.out.println("channel=" + channel);
        try {
            // Simulate a slow operation.
            Thread.sleep(5000);
        } catch (InterruptedException e) {
            // Restore the interrupt flag instead of swallowing it.
            Thread.currentThread().interrupt();
        }
        return new ApiResult("success", "data");
    }
}
6)测试结果
重复点击:在第一次请求仍在执行时(示例接口会休眠 5 秒)再次调用接口,第二次请求会直接返回 data 为"重复提交"的结果。