aop实现接口访问频率限制

本文主要介绍通过注解形式给接口加上访问频率限制,避免别有用心之人恶意调用,原理有点类似于前端的节流,细节如下:

首先需要定义一个注解,用来标记controller方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
package com.test.common.annotation;

import com.test.common.constant.PreventStrategy;

import java.lang.annotation.*;


@Documented
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Prevent {
/**
* 限制的时间值(秒)默认10s
*/
long value() default 10;

/**
* 限制规定时间内访问次数,默认只能访问一次
*/
long times() default 1;

/**
* 提示
*/
String message() default "";

/**
* 策略
*/
PreventStrategy strategy() default PreventStrategy.DEFAULT;
}

PreventStrategy 是自定义的策略枚举,默认使用默认值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package com.test.common.aop;

import cn.hutool.extra.servlet.ServletUtil;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.test.common.annotation.Prevent;
import com.test.common.api.ResultCode;
import com.test.common.constant.CommonConstant;
import com.test.common.constant.PreventStrategy;
import com.test.common.exception.ServiceException;
import com.test.common.util.redis.RedisUtil;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Objects;
import java.util.concurrent.TimeUnit;


@Aspect
@Component
public class PreventAop {

/**
* 切入点
*/
@Pointcut("@annotation(com.test.common.annotation.Prevent)")
public void pointcut() {}

/**
* 处理前
*/
@Before("pointcut()")
public void joinPoint(JoinPoint joinPoint) throws Exception {
// 获取调用者ip
RequestAttributes requestAttributes = RequestContextHolder.currentRequestAttributes();
HttpServletRequest httpServletRequest = ((ServletRequestAttributes) requestAttributes).getRequest();
String ip = ServletUtil.getClientIP(httpServletRequest);
// 获取调用接口方法名
MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
// 获取该接口方法
Method method = joinPoint.getTarget().getClass().getMethod(
methodSignature.getName(),
methodSignature.getParameterTypes());
// 获取到方法名
String className = method.getDeclaringClass().getName();
className = className.substring(className.lastIndexOf(StringPool.DOT)+1);
String methodFullName = String.format("%s.%s", className, method.getName());
// 获取该接口上的prevent注解(为了使用该注解内的参数)
Prevent preventAnnotation = method.getAnnotation(Prevent.class);
// 执行对应策略
entrance(preventAnnotation, ip, methodFullName);
}

/**
* 通过prevent注册判断执行策略
* @param prevent 该接口的prevent注解对象
* @param userIp 访问该接口的用户ip
* @param methodFullName 该接口方法名
*/
private void entrance(Prevent prevent, String userIp, String methodFullName) throws Exception {
PreventStrategy strategy = prevent.strategy();
if (Objects.requireNonNull(strategy) == PreventStrategy.DEFAULT) {
defaultHandle(userIp, prevent, methodFullName);
} else {
throw new ServiceException(ResultCode.FREQUENT_REQUEST);
}
}

private void defaultHandle(String userIp, Prevent prevent, String methodFullName) throws Exception {
// 加密用户ip(避免ip存在一些特殊字符作为redis的key不合法)
String base64Ip = toBase64String(userIp);
long expire = prevent.value();
long times = prevent.times();
String key = String.format("%s:ip:%s:%s", CommonConstant.REDIS_KEY_PREFIX, base64Ip, methodFullName);
// 限制特定时间内访问特定次数
Long count = RedisUtil.redis.opsForValue().increment(key, 1);
if (Objects.isNull(count)) {
return;
}
// 如果访问次数为1,则重置访问限制时间(即redis超时时间)
if (count == 1) {
RedisUtil.redis.expire(key, expire, TimeUnit.SECONDS);
}
// 如果访问次数超出访问限制次数,则禁止访问
if (count > times) {
String errorMessage =
StringUtils.isNotEmpty(prevent.message()) ? prevent.message() : ResultCode.FREQUENT_REQUEST.getMsg();
throw new ServiceException(ResultCode.FREQUENT_REQUEST.getCode(), errorMessage);
}
}

/**
* 对象转换为base64字符串
* @param obj 对象值
* @return base64字符串
*/
private String toBase64String(String obj) throws Exception {
if (StringUtils.isEmpty(obj)) {
return null;
}
Base64.Encoder encoder = Base64.getEncoder();
byte[] bytes = obj.getBytes(StandardCharsets.UTF_8);
return encoder.encodeToString(bytes);
}
}

上面的注释已经很详细了,默认策略会统计访问次数,设置超时时间,次数到达上限会抛出异常,直到key过期解除限制。

使用时标记在接口方法上就行了,可根据需要覆盖@Prevent的默认值

1
2
3
4
5
@PostMapping("/form")
@Prevent()
public Result<String> addForm(@Valid @RequestBody FormCreateReqVo reqVo) {
return success(projectInspectionLotClientService.addForm(reqVo));
}