赞
踩
<!-- Redis support (RedisTemplate, RedisScript) used by the rate limiter -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- Web layer: controller mappings and the @RestControllerAdvice exception handler -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- AOP support for the @Aspect that intercepts @RateLimiter-annotated methods -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
/**
 * Strategies that decide how a rate-limit counter is scoped.
 *
 * @author jigua
 * @version 1.0
 */
public enum LimitType {

    /** Default strategy: a single, global limit. */
    DEFAULT,

    /** Limit requests separately for each client IP address. */
    IP
}
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a method as rate limited.
 *
 * <p>The annotation is retained at runtime and may only be placed on methods.
 *
 * @author jigua
 * @version 1.0
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {

    /**
     * Key used for the rate-limit counter.
     *
     * @return the key, {@code "rate_limit:"} by default
     */
    String key() default "rate_limit:";

    /**
     * Duration of the limiting period, in seconds.
     *
     * @return the period in seconds, 30 by default
     */
    int time() default 30;

    /**
     * Number of calls permitted within the limiting period.
     *
     * @return the permitted call count, 300 by default
     */
    int count() default 300;

    /**
     * Strategy that scopes the limit.
     *
     * @return the limit type, {@link LimitType#DEFAULT} by default
     */
    LimitType limitType() default LimitType.DEFAULT;
}
import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.ClassPathResource; import org.springframework.data.redis.connection.RedisConnectionFactory; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.data.redis.core.script.DefaultRedisScript; import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer; import org.springframework.scripting.support.ResourceScriptSource; /** * @author jigua * @version 1.0 * @className RedisConfig * @description 修改 RedisTemplate 序列化方案 * @create 2022/6/17 14:12 */ @Configuration public class RedisConfig { @Bean public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) { RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>(); redisTemplate.setConnectionFactory(connectionFactory); // 使用Jackson2JsonRedisSerialize 替换默认序列化(默认采用的是JDK序列化) //JdkSerializationRedisSerializer序列化时有可能会出现一些莫名其妙的前缀导致读取失败 Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class); ObjectMapper om = new ObjectMapper(); om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY); om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL); jackson2JsonRedisSerializer.setObjectMapper(om); redisTemplate.setKeySerializer(jackson2JsonRedisSerializer); redisTemplate.setValueSerializer(jackson2JsonRedisSerializer); redisTemplate.setHashKeySerializer(jackson2JsonRedisSerializer); redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer); System.out.println("redis加载完成"); return redisTemplate; } @Bean public DefaultRedisScript<Long> limitScript() { DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>(); redisScript.setScriptSource(new 
ResourceScriptSource(new ClassPathResource("lua/limit.lua"))); redisScript.setResultType(Long.class); System.out.println("lua脚本加载完成"); return redisScript; } }
-- Rate-limit counter script.
--   KEYS[1]  counter key
--   ARGV[1]  maximum allowed count
--   ARGV[2]  expiry, in seconds, applied when the counter is first created
-- Returns the counter value as a number.
local limitKey = KEYS[1]
local maxCount = tonumber(ARGV[1])
local expirySeconds = tonumber(ARGV[2])

-- Already over the limit: report the stored value without incrementing again.
local stored = redis.call('get', limitKey)
if stored then
  local storedCount = tonumber(stored)
  if storedCount > maxCount then
    return storedCount
  end
end

-- Count this request; the first increment also starts the expiry timer.
local updated = tonumber(redis.call('incr', limitKey))
if updated == 1 then
  redis.call('expire', limitKey, expirySeconds)
end
return updated
使用lua脚本的两种思路
脚本解释
import com.example.huibaozi.util.IpUtil; import com.google.protobuf.ServiceException; import org.aspectj.lang.JoinPoint; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Before; import org.aspectj.lang.reflect.MethodSignature; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.data.redis.core.script.RedisScript; import org.springframework.stereotype.Component; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import java.lang.reflect.Method; import java.util.Collections; import java.util.List; /** * @author jigua * @version 1.0 * @className RateLimiterAspect * @description * @create 2022/6/17 14:22 */ @Aspect @Component public class RateLimiterAspect { private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class); @Autowired private RedisTemplate<Object, Object> redisTemplate; @Autowired private RedisScript<Long> limitScript; @Before("@annotation(rateLimiter)") public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable { String key = rateLimiter.key(); int time = rateLimiter.time(); int count = rateLimiter.count(); String combineKey = getCombineKey(rateLimiter, point); List<Object> keys = Collections.singletonList(combineKey); try { Long number = redisTemplate.execute(limitScript, keys, count, time); if (number == null || number.intValue() > count) { throw new ServiceException("访问过于频繁,请稍候再试"); } log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), key); } catch (ServiceException e) { throw e; } catch (Exception e) { throw new RuntimeException("服务器限流异常,请稍候再试"); } } public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) { StringBuffer stringBuffer = new StringBuffer(rateLimiter.key()); if (rateLimiter.limitType() == 
LimitType.IP) { stringBuffer.append(IpUtil.getIpAddr(((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes()).getRequest())).append("-"); } MethodSignature signature = (MethodSignature) point.getSignature(); Method method = signature.getMethod(); Class<?> targetClass = method.getDeclaringClass(); stringBuffer.append(targetClass.getName()).append("-").append(method.getName()); return stringBuffer.toString(); } }
import com.google.protobuf.ServiceException;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

import java.util.HashMap;
import java.util.Map;

/**
 * Controller advice that turns a {@link ServiceException} into a small
 * status/message response body.
 *
 * @author jigua
 * @version 1.0
 */
@RestControllerAdvice
public class GlobalException {

    /**
     * Handles {@link ServiceException} with a fixed response body.
     *
     * @param e the exception that was raised (its message is not used)
     * @return a map with {@code status} set to 500 and a fixed {@code message}
     */
    @ExceptionHandler(ServiceException.class)
    public Map<String,Object> serviceException(ServiceException e) {
        final Map<String, Object> body = new HashMap<>(2);
        body.put("status", 500);
        body.put("message", "访问人数过多");
        return body;
    }
}
基于ip的测试
/**
 * Sample GET endpoint at {@code /testIp}, annotated with an IP-based rate limit
 * of count 3 and time 50.
 *
 * @return the literal string {@code "success"}
 */
@GetMapping("/testIp")
@ResponseBody
@RateLimiter(limitType = LimitType.IP, count = 3, time = 50)
public String testIp() throws Exception {
    final String reply = "success";
    return reply;
}
测试时如果遇到这样的情况不要慌,把请求路径中的 localhost 替换成本机的 IP 地址即可。
基于全局的测试
/**
 * Sample GET endpoint at {@code /testService}, annotated with the default
 * rate-limit type, count 3 and time 50.
 *
 * @return a JSON object containing the single entry {@code "1" -> 1}
 */
@GetMapping("/testService")
@ResponseBody
@RateLimiter(limitType = LimitType.DEFAULT, count = 3, time = 50)
public JSONObject testService() throws Exception {
    JSONObject payload = new JSONObject();
    payload.put("1", 1);
    return payload;
}
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。