Implementing a Spring Boot filter that blocks IP addresses making overly frequent requests

created at 07-26-2021 views: 1

introduction

When developing a web project, we often need filters to pre-process requests, for example to convert character sets or to log requests. In traditional Java web development, filters were configured in web.xml. Spring Boot has no web.xml, so how do we register them?

1 Write a filter:

import lombok.extern.slf4j.Slf4j;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Servlet filter that blocks client IPs which send requests too frequently.
 *
 * <p>The filter keeps two maps in the servlet context:
 * <ul>
 *   <li>{@code "ipMap"}: IP -> {visit count, window start time in ms}.</li>
 *   <li>{@code "limitedIpMap"}: blocked IP -> time (ms) at which the block expires.</li>
 * </ul>
 *
 * <p>An IP that makes more than {@link #LIMIT_NUMBER} requests within {@link #MIN_SAFE_TIME} ms is
 * blocked for {@link #LIMITED_TIME_MILLIS} ms. Blocked requests get a {@code "remainingTime"}
 * request attribute (seconds, rounded up), and a {@link RuntimeException} is thrown.
 * NOTE(review): the exception presumably goes to an error page that reads
 * {@code remainingTime}; confirm with the application's error handling.
 */
@Slf4j
@WebFilter(urlPatterns="/dyflight/*")
public class IpFilter implements Filter{

  /**
   * How long an IP stays blocked once it has been flagged (unit: ms); 10 * 1000 = 10 seconds.
   */
  private static final long LIMITED_TIME_MILLIS = 10 * 1000;

  /**
   * Maximum number of requests allowed within one {@link #MIN_SAFE_TIME} window. The request that
   * pushes the count above this value inside the window gets the IP blocked.
   */
  private static final int LIMIT_NUMBER = 5;

  /**
   * Length of the counting window (unit: ms). More than {@link #LIMIT_NUMBER} requests within this
   * window mark the IP as malicious. Otherwise, the window restarts once the count passes the limit.
   */
  private static final int MIN_SAFE_TIME = 5000;

  /** Servlet-context attribute holding IP -> {visit count, window start ms}. */
  private static final String IP_MAP_ATTRIBUTE = "ipMap";

  /** Servlet-context attribute holding blocked IP -> block expiry time (ms). */
  private static final String LIMITED_IP_MAP_ATTRIBUTE = "limitedIpMap";

  private FilterConfig config;

  /**
   * Stores the filter configuration and makes sure both shared maps exist.
   *
   * <p>MyApplicationListener normally creates the maps. Creating them here as a fallback avoids a
   * NullPointerException in {@link #doFilter} when the listener is not registered. The container
   * calls {@code init} before the filter serves requests, so no locking is needed here.
   */
  @Override
  public void init(FilterConfig filterConfig) throws ServletException {
    this.config = filterConfig;
    ServletContext context = filterConfig.getServletContext();
    if (context.getAttribute(IP_MAP_ATTRIBUTE) == null) {
      context.setAttribute(IP_MAP_ATTRIBUTE, new ConcurrentHashMap<String, Long[]>());
    }
    if (context.getAttribute(LIMITED_IP_MAP_ATTRIBUTE) == null) {
      context.setAttribute(LIMITED_IP_MAP_ATTRIBUTE, new ConcurrentHashMap<String, Long>());
    }
  }

  /**
   * Rejects requests from currently blocked IPs and counts the request otherwise. Blocks the IP
   * when this request exceeds the rate limit.
   *
   * @throws RuntimeException if the IP is blocked or has just been blocked
   */
  @Override
  public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain)
      throws IOException, ServletException {
    HttpServletRequest request = (HttpServletRequest) servletRequest;
    HttpServletResponse response = (HttpServletResponse) servletResponse;
    ServletContext context = config.getServletContext();
    ConcurrentHashMap<String, Long> limitedIpMap = getMap(context, LIMITED_IP_MAP_ATTRIBUTE);
    ConcurrentHashMap<String, Long[]> ipMap = getMap(context, IP_MAP_ATTRIBUTE);

    long now = System.currentTimeMillis();
    // Drop blocks that have already expired.
    filterLimitedIpMap(limitedIpMap, now);

    // NOTE(review): the IP may come from client-supplied proxy headers (see IPUtil), so a client
    // that can set X-Forwarded-For freely can evade this limit.
    String ip = IPUtil.getRemoteIpAddr(request);
    log.debug("ip:{}", ip);
    if (ip == null) {
      // ConcurrentHashMap rejects null keys; an unidentifiable client cannot be rate limited.
      chain.doFilter(request, response);
      return;
    }

    // Read the expiry once: another request thread may purge the entry between a check and a get.
    Long blockedUntil = limitedIpMap.get(ip);
    if (blockedUntil != null && blockedUntil > now) {
      request.setAttribute("remainingTime", toSecondsRoundedUp(blockedUntil - now));
      log.debug("IP access is too frequent:{}", ip);
      throw new RuntimeException("IP access is too frequent");
    }

    if (recordVisit(ipMap, ip, now)) {
      limitedIpMap.put(ip, now + LIMITED_TIME_MILLIS);
      // Same unit (seconds) as the already-blocked branch above.
      request.setAttribute("remainingTime", toSecondsRoundedUp(LIMITED_TIME_MILLIS));
      log.warn("IP access is too frequent:{}", ip);
      throw new RuntimeException("IP access is too frequent");
    }
    chain.doFilter(request, response);
  }

  @Override
  public void destroy() {
    // Nothing to release: the shared maps live in the servlet context.
  }

  /**
   * Counts one request from {@code ip} and reports whether it exceeds the rate limit.
   *
   * <p>The update runs inside {@link ConcurrentHashMap#compute}, so concurrent requests from the
   * same IP are counted atomically. When the count passes {@link #LIMIT_NUMBER} but the window
   * started more than {@link #MIN_SAFE_TIME} ms ago, a new window starts. This request is counted
   * as the new window's first visit.
   *
   * NOTE(review): entries are never evicted, so ipMap grows with every distinct client IP; consider
   * a periodic cleanup.
   *
   * @param ipMap shared IP -> {visit count, window start ms} map
   * @param ip client IP (non-null)
   * @param now current time in ms
   * @return true if the IP made more than LIMIT_NUMBER requests within MIN_SAFE_TIME
   */
  private static boolean recordVisit(ConcurrentHashMap<String, Long[]> ipMap, String ip, long now) {
    Long[] ipInfo = ipMap.compute(ip, (key, previous) -> {
      if (previous == null) {
        return newVisitRecord(now);
      }
      long visits = previous[0] + 1;
      if (visits > LIMIT_NUMBER && now - previous[1] > MIN_SAFE_TIME) {
        // Limit reached only after the window elapsed: normal traffic, start a new window.
        return newVisitRecord(now);
      }
      // A fresh array instead of mutating the old one keeps every published record unchanged.
      return new Long[] {visits, previous[1]};
    });
    log.debug("present [{}] accesses", ipInfo[0]);
    return ipInfo[0] > LIMIT_NUMBER;
  }

  /** Creates a visit record for a new window: one visit, window starting at {@code now}. */
  private static Long[] newVisitRecord(long now) {
    return new Long[] {1L, now};
  }

  /**
   * Removes blocks whose expiry time is at or before {@code now}.
   *
   * <p>{@code removeIf} on the values view avoids the check-then-get race of iterating the keys and
   * calling {@code get()} for each one.
   */
  private static void filterLimitedIpMap(ConcurrentHashMap<String, Long> limitedIpMap, long now) {
    limitedIpMap.values().removeIf(expireTimeMillis -> expireTimeMillis <= now);
  }

  /** Converts milliseconds to whole seconds, rounding any partial second up. */
  private static long toSecondsRoundedUp(long millis) {
    return (millis / 1000) + (millis % 1000 > 0 ? 1 : 0);
  }

  /** Reads one of the shared maps from the servlet context (created in {@link #init}). */
  @SuppressWarnings("unchecked") // the attributes are only ever set to maps of these types
  private static <V> ConcurrentHashMap<String, V> getMap(ServletContext context, String name) {
    return (ConcurrentHashMap<String, V>) context.getAttribute(name);
  }
}

2 Create a listener that initializes the two shared maps (per-IP visit records and blocked IPs):

import lombok.extern.slf4j.Slf4j;

import javax.servlet.ServletContext;
import javax.servlet.ServletContextEvent;
import javax.servlet.ServletContextListener;
import javax.servlet.annotation.WebListener;
import java.util.concurrent.ConcurrentHashMap;


/**
 * Application-startup listener that creates the two shared maps used by the IP rate-limiting
 * filter and publishes them as servlet-context attributes.
 */
@Slf4j
@WebListener
public class MyApplicationListener implements ServletContextListener {

  /**
   * Creates and registers the shared maps:
   * <ul>
   *   <li>{@code "ipMap"}: IP storage, IP -> {@code Long[]} visit record.</li>
   *   <li>{@code "limitedIpMap"}: restricted-IP storage, IP -> {@code Long} value.</li>
   * </ul>
   */
  @Override
  public void contextInitialized(ServletContextEvent sce) {
    log.debug("liting: contextInitialized");
    log.debug("MyApplicationListener初始化成功");

    final ServletContext servletContext = sce.getServletContext();

    // IP storage: visit record per client IP.
    final ConcurrentHashMap<String, Long[]> visitsByIp = new ConcurrentHashMap<>();
    // Restricted IP storage: information about currently restricted IPs.
    final ConcurrentHashMap<String, Long> restrictedIps = new ConcurrentHashMap<>();

    servletContext.setAttribute("ipMap", visitsByIp);
    servletContext.setAttribute("limitedIpMap", restrictedIps);

    log.debug("ipmap:" + visitsByIp + ";limitedIpMap:" + restrictedIps
        + "Initialization successful. . . . .");
  }

  /** No cleanup is performed on shutdown. */
  @Override
  public void contextDestroyed(ServletContextEvent sce) {
    // Intentionally empty.
  }
}

3 IPUtil (resolves the client IP address)

import javax.servlet.http.HttpServletRequest;
import java.net.InetAddress;
import java.net.UnknownHostException;

/**
 * Helpers for determining the client IP address of an HTTP request.
 */
public class IPUtil {

  /** Placeholder value some proxies send when they do not know the client address. */
  private static final String UNKNOWN = "unknown";

  /** Proxy headers checked, in order, for the original client address. */
  private static final String[] IP_HEADERS = {
    "x-forwarded-for",
    "Proxy-Client-IP",
    "WL-Proxy-Client-IP",
    "HTTP_CLIENT_IP",
    "HTTP_X_FORWARDED_FOR"
  };

  /**
   * Returns the best-effort client IP address for {@code request}.
   *
   * <p>The proxy headers in {@link #IP_HEADERS} are tried first. For a comma-separated chain such as
   * {@code "client, proxy1, proxy2"}, only the first (client) entry is used. If no header carries a
   * usable value, {@link HttpServletRequest#getRemoteAddr()} is used. A loopback address is replaced
   * with this machine's own address when that can be resolved.
   *
   * NOTE(review): these headers are client-controlled unless a trusted proxy overwrites them, so
   * the result must not be trusted for security decisions without such a proxy.
   *
   * @param request the incoming request
   * @return the client IP address, as determined above
   */
  public static String getRemoteIpAddr(HttpServletRequest request) {
    for (String header : IP_HEADERS) {
      String candidate = firstAddress(request.getHeader(header));
      if (isKnown(candidate)) {
        return candidate;
      }
    }
    String ip = request.getRemoteAddr();
    if ("127.0.0.1".equals(ip) || "0:0:0:0:0:0:0:1".equals(ip) || "::1".equals(ip)) {
      // Get the IP configured for this machine's network card.
      try {
        ip = InetAddress.getLocalHost().getHostAddress();
      } catch (UnknownHostException ignored) {
        // Best effort: keep the loopback address instead of failing the request.
      }
    }
    return ip;
  }

  /** Returns the first comma-separated entry of a header value, trimmed, or null if absent. */
  private static String firstAddress(String headerValue) {
    if (headerValue == null) {
      return null;
    }
    int comma = headerValue.indexOf(',');
    return (comma >= 0 ? headerValue.substring(0, comma) : headerValue).trim();
  }

  /** True if {@code ip} is a non-empty value other than the {@code "unknown"} placeholder. */
  private static boolean isKnown(String ip) {
    return ip != null && !ip.isEmpty() && !UNKNOWN.equalsIgnoreCase(ip);
  }
}

4 Configuration

In a Spring Boot application, add package scanning for the filter and listener to the startup class:

@ServletComponentScan(basePackages="cn.xxx.common")

In a traditional Spring (non-Boot) project, register them in web.xml instead:

Filter:

<filter>
    <filter-name>ipFilter</filter-name>
    <filter-class>com.xxxx.common.filter.IpFilter</filter-class>
  </filter>
  <filter-mapping>
    <filter-name>ipFilter</filter-name>
    <url-pattern>/dyflight/*</url-pattern>
  </filter-mapping>

Listener:

<listener>
    <listener-class>com.xxxx.common.Listener.MyApplicationListener</listener-class>
</listener>
created at:07-26-2021
edited at: 07-26-2021: