MyBatis 拦截器中获取 @RequestBody 请求体(JSON)的值并修改查询 SQL
背景:
我们需要获取 Controller 接口中前端传入的 JSON 对象参数值,并据此修改本次接口调用的查询 SQL 语句。如果后台接收的是表单数据,通过 request.getParameterMap() 就可以全部获取;如果是 JSON 对象数据,在过滤器或拦截器中通过 request.getInputStream() 读取了请求的输入流之后,请求到达 Controller 层时就会报错——原因在于 request 的输入流只能读取一次,不能重复读取。
1.示例:定义Controller查询UserList
/**
 * Queries the user list for the criteria carried in the JSON request body and returns it
 * wrapped in a {@code PageDataInfo}.
 *
 * @param req query parameters bound from the JSON request body (presumably also carries the
 *            paging parameters consumed by {@code PageUtils.startPage} — TODO confirm)
 * @return the users returned by {@code userInfoService.getUserList(req)}, wrapped by
 *         {@code PageUtils.buildPageDataInfo}
 */
@PostMapping("/user/list")
public PageDataInfo<UserInfo> getUserList(@RequestBody ChkReq req) {
// NOTE(review): looks like PageHelper-style pagination bound to the current thread, which would
// only apply to the next query executed — keep this call immediately before the service query.
PageUtils.startPage(req);
return PageUtils.buildPageDataInfo(userInfoService.getUserList(req));
}
2. 定义一个请求包装类(RequestWrapper),将请求体缓存到内存中的字节数组里,之后每次获取输入流时都从这份缓存重新读取
/**
 * Request wrapper that buffers the request body so it can be read any number of times
 * (for example once by an interceptor and again by {@code @RequestBody} binding).
 *
 * <p>The body is kept as raw bytes; text views decode it with the request's declared character
 * encoding, or UTF-8 when none is declared or the declared one is unsupported.
 */
@Slf4j
public class RequestWrapper extends HttpServletRequestWrapper {
    /**
     * Raw request body, copied once in the constructor.
     */
    private final byte[] body;

    /**
     * Charset used to turn {@link #body} into text.
     */
    private final Charset charset;

    /**
     * Wraps {@code request}, fully consuming and buffering its input stream.
     *
     * @param request the request to wrap
     * @throws java.io.UncheckedIOException if the body cannot be read
     */
    public RequestWrapper(HttpServletRequest request) {
        super(request);
        // Keep the raw bytes: decoding and re-encoding with the platform charset would corrupt
        // non-ASCII JSON, and line-by-line reading would drop line breaks from the body.
        this.body = readBody(request);
        this.charset = resolveCharset(request);
    }

    /**
     * Reads the (remaining) body of {@code request} and decodes it with that request's charset.
     *
     * @param request request whose input stream is consumed
     * @return the decoded body text
     * @throws java.io.UncheckedIOException if the body cannot be read
     */
    public String getBodyString(final ServletRequest request) {
        return new String(readBody(request), resolveCharset(request));
    }

    /**
     * Returns the buffered body decoded with the request's charset.
     *
     * @return the body text; empty when the request had no body
     */
    public String getBodyString() {
        return new String(body, charset);
    }

    /**
     * Returns a fresh reader over the buffered body, decoded with the request's charset.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), charset));
    }

    /**
     * Returns a fresh stream positioned at the start of the buffered body.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream inputStream = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return inputStream.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                // Bulk read avoids a virtual call per byte for large bodies.
                return inputStream.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return inputStream.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The whole body is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reading is not supported; data is always read synchronously.
            }
        };
    }

    /** Copies every remaining byte of the request's input stream into an array. */
    private static byte[] readBody(ServletRequest request) {
        try {
            InputStream in = request.getInputStream();
            java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream();
            byte[] buffer = new byte[4096];
            int n;
            while ((n = in.read(buffer)) != -1) {
                out.write(buffer, 0, n);
            }
            return out.toByteArray();
        } catch (IOException e) {
            throw new java.io.UncheckedIOException("Failed to read request body", e);
        }
    }

    /** Returns the request's declared charset, or UTF-8 when absent or unsupported. */
    private static Charset resolveCharset(ServletRequest request) {
        String encoding = request.getCharacterEncoding();
        if (encoding != null) {
            try {
                return Charset.forName(encoding);
            } catch (IllegalArgumentException e) {
                // Covers both IllegalCharsetNameException and UnsupportedCharsetException.
                log.warn("Unsupported request charset '{}', falling back to UTF-8", encoding);
            }
        }
        return java.nio.charset.StandardCharsets.UTF_8;
    }
}
3.我们要在过滤器中将原生的HttpServletRequest换成RequestWrapper对象
/**
 * Servlet filter that replaces JSON requests with a {@link RequestWrapper}, so that their body
 * can be read more than once.
 */
public class ReplaceStreamFilter implements Filter {
    /** Delegates to the default {@link Filter#init} implementation. */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        Filter.super.init(filterConfig);
    }

    /**
     * Wraps {@code application/json} HTTP requests in a {@link RequestWrapper}; all other requests
     * are passed through unchanged.
     *
     * <p>Only JSON bodies are wrapped. Wrapping consumes the original input stream, after which
     * the container no longer parses POST form parameters ({@code getParameterMap()} would lose
     * them), and multipart uploads would be buffered entirely in memory.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        ServletRequest request = servletRequest;
        if (servletRequest instanceof HttpServletRequest && isJson(servletRequest.getContentType())) {
            request = new RequestWrapper((HttpServletRequest) servletRequest);
        }
        filterChain.doFilter(request, servletResponse);
    }

    /** Delegates to the default {@link Filter#destroy} implementation. */
    @Override
    public void destroy() {
        Filter.super.destroy();
    }

    /** Case-insensitive check for a JSON content type (media types are case-insensitive). */
    private static boolean isJson(String contentType) {
        return contentType != null
                && contentType.toLowerCase(java.util.Locale.ROOT).contains("application/json");
    }
}
4.注册过滤器
/**
 * Spring configuration that registers the stream-replacing servlet filter.
 */
@Configuration
public class FilterConfig {
    /**
     * Registers the {@code replaceStreamFilter} bean as a servlet filter named "streamFilter"
     * that is mapped to every URL ("/*").
     *
     * @return the filter registration
     */
    @Bean
    public FilterRegistrationBean someFilterRegistration() {
        FilterRegistrationBean streamFilterBean = new FilterRegistrationBean(replaceStreamFilter());
        streamFilterBean.setName("streamFilter");
        streamFilterBean.addUrlPatterns("/*");
        return streamFilterBean;
    }

    /**
     * Creates the {@link ReplaceStreamFilter} instance, exposed as the bean "replaceStreamFilter".
     *
     * @return a new filter instance
     */
    @Bean(name = "replaceStreamFilter")
    public Filter replaceStreamFilter() {
        return new ReplaceStreamFilter();
    }
}
5. 然后就可以在拦截器中读取 JSON 请求体了
public class MyRequestInterceptor implements HandlerInterceptor {
private ObjectMapper objectMapper = new ObjectMapper();
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
if ("POST".equalsIgnoreCase(request.getMethod()) && request.getContentType() != null && request.getContentType().contains("application/json")) {
/*try {
byte[] requestBodyBytes = readRequestBody(request);
String requestBody = new String(requestBodyBytes, StandardCharsets.UTF_8);*/
try (BufferedReader reader = request.getReader()) {
StringBuilder requestBody = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
requestBody.append(line);
}
// 将请求体转换为 ChkReq 对象
ChkReq chkReq = objectMapper.readValue(requestBody.toString(), ChkReq.class);
// 将 ChkReq 对象存储在 ServletRequestAttributes 中
ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (attributes != null) {
attributes.getRequest().setAttribute("ChkReq", chkReq);
}
// 继续处理请求
return true;
} catch (IOException e) {
// 处理异常,例如返回错误响应
response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
response.getWriter().write("Invalid JSON data");
return false;
}
}
// 如果不是 JSON 请求或者不是 POST 方法,则继续处理请求
return true;
}
}
6.注册拦截器
/**
 * Spring MVC configuration that installs the request interceptor.
 */
@Configuration
public class WebConfig implements WebMvcConfigurer {
    /**
     * Registers a new {@link MyRequestInterceptor} for all request paths.
     *
     * @param registry the interceptor registry supplied by Spring MVC
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        MyRequestInterceptor requestInterceptor = new MyRequestInterceptor();
        // "/**" matches every path; narrow this to limit which requests are intercepted.
        registry.addInterceptor(requestInterceptor).addPathPatterns("/**");
    }
}
7. 在 MyBatis 拦截器中获取请求中保存的参数值,并据此修改 SQL
@Component
public class MyInterceptor implements InnerInterceptor {
@SneakyThrows
@Override
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
//InnerInterceptor.super.beforeQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql);
String sql = boundSql.getSql();
System.out.println("sql更新之前:" + sql);
//String condition = " name = '李四' " ;
String condition = " 1 = 1 ";
String name = null;
ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (attributes != null) {
ChkReq chkReq = (ChkReq) attributes.getRequest().getAttribute("ChkReq");
if (chkReq != null) {
name = "name = '" + chkReq.getName() + "'";
}
}
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
Select select = (Select) CCJSqlParserUtil.parse(sql);
PlainSelect plainSelect = (PlainSelect) select.getSelectBody();
final Expression expression = plainSelect.getWhere();
final Expression envCondition = CCJSqlParserUtil.parseCondExpression(condition);
final Expression envCondition2 = CCJSqlParserUtil.parseCondExpression(name);
if (expression == null) {
plainSelect.setWhere(envCondition);
plainSelect.setWhere(envCondition2);
} else {
AndExpression andExpression = new AndExpression(expression, envCondition);
AndExpression andExpression2 = new AndExpression(andExpression, envCondition2);
plainSelect.setWhere(andExpression2);
}
mpBs.sql(plainSelect.toString());
System.out.println("sql更新之后:" + plainSelect.toString());
}
}