侧边栏壁纸
博主头像
孔子说JAVA博主等级

成功只是一只沦落在鸡窝里的鹰,成功永远属于自信且有毅力的人!

  • 累计撰写 352 篇文章
  • 累计创建 135 个标签
  • 累计收到 10 条评论

目 录CONTENT

文章目录

spring web项目使用过滤器防止sql注入

孔子说JAVA
2022-09-25 / 0 评论 / 0 点赞 / 72 阅读 / 10,993 字 / 正在检测是否收录...
广告 广告

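公司渗透测试中发现输入框存在 SQL 注入风险。该项目为 Spring Web 项目,这里通过增加一个全局过滤器来解析请求参数,一旦参数中包含过滤器禁用的 SQL 关键字,就拒绝该请求并给出提示。需要注意:关键字黑名单只是临时的补救手段,既会误拦正常输入,也可能被绕过;根治 SQL 注入应使用参数化查询(JDBC 的 PreparedStatement、MyBatis 的 #{} 占位符)。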

1、过滤器实现方式一

1.1 过滤器类

package com.kz.common.utils;
 
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Enumeration;
 
/**
 * sql注入过滤器
 **/
public class SqlInjectionFilter implements Filter{
 
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
 
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse res = (HttpServletResponse) response;
        //获得所有请求参数名
        Enumeration params = req.getParameterNames();
 
        String sql = "";
        while (params.hasMoreElements()) {
            //得到参数名
            String name = params.nextElement().toString();
            //System.out.println("name===========================" + name + "--");
            //得到参数对应值
            String[] value = req.getParameterValues(name);
            for (int i = 0; i < value.length; i++) {
                sql = sql + value[i];
            }
        }
        //有sql关键字,跳转到error.html
        if (sqlValidate(sql)) {
            throw new IOException("您发送请求中的参数中含有非法字符");
        } else {
            chain.doFilter(req, res);
        }
    }
 
    //效验
    protected static boolean sqlValidate(String str) {
        str = str.toLowerCase();//统一转为小写
        //String badStr = "'|and|exec|execute|insert|select|delete|update|count|drop|*|%|chr|mid|master|truncate|char|declare|sitename|net user|xp_cmdshell|;|or|-|+|,|like";
        String badStr = "'|and|exec|execute|insert|create|drop|table|from|grant|use|group_concat|column_name|" +
                "information_schema.columns|table_schema|union|where|select|delete|update|order|by|count|*|" +
                "chr|mid|master|truncate|char|declare|or|;|-|--|+|,|like|//|/|%|#";//过滤掉的sql关键字,可以手动添加
        String[] badStrs = badStr.split("\\|");
        for (int i = 0; i < badStrs.length; i++) {
            if (str.indexOf(badStrs[i]) !=-1) {
                return true;
            }
        }
        return false;
    }
 
    public void init(FilterConfig filterConfig) throws ServletException {
        //throw new UnsupportedOperationException("Not supported yet.");
    }
 
    public void destroy() {
        //throw new UnsupportedOperationException("Not supported yet.");
    }
}

init(FilterConfig filterConfig) 为过滤器初始化方法,容器创建过滤器实例后调用一次,可通过 FilterConfig 读取 web.xml 中为该过滤器配置的 <init-param> 初始化参数;本例不需要初始化,因此方法体为空。

1.2 web.xml配置

在web.xml中添加过滤器类的配置。

 <!-- Declare the filter class under the name SqlInjectionFilter (no init-param is configured). -->
 <filter>
    <filter-name>SqlInjectionFilter</filter-name>
    <filter-class>com.kz.common.utils.SqlInjectionFilter</filter-class>
  </filter>
  <!-- Map the filter declared above to every request path. -->
  <filter-mapping>
    <filter-name>SqlInjectionFilter</filter-name>
    <url-pattern>/*</url-pattern>
  </filter-mapping>

配置各节点介绍:

节点名 介绍
<filter> 指定一个过滤器
<filter-name> 用于为过滤器指定一个名字,该元素的内容不能为空
<filter-class> 指定过滤器的完整的限定类名
<init-param> 为过滤器指定初始化参数。在过滤器中,可以使用FilterConfig接口对象来访问初始化参数
<param-name> <init-param>的子元素,指定参数的名字
<param-value> <init-param>的子元素,指定参数的值
<filter-mapping> 设置一个Filter所负责拦截的资源。可通过Servlet名称或资源访问的请求路径指定
<filter-name> 子元素用于设置filter的注册名称。该值必须是在<filter>元素中声明过的过滤器的名字
<url-pattern> 设置 filter 所拦截的请求路径(过滤器关联的URL样式)
<servlet-name> 指定过滤器所拦截的Servlet名称
<dispatcher> 指定过滤器所拦截的资源被 Servlet 容器调用的方式,默认REQUEST

2、过滤器实现方式二

将过滤器 SqlInjectFilter 与请求封装类 MyRequestWrapper 结合使用:MyRequestWrapper 负责缓存请求体,使过滤器能够读取 POST 请求体,而后续的 Controller 仍可再次读取。下面给出 SqlInjectFilter 的两种写法。

2.1 过滤器类

SqlInjectFilter.java(写法一:实现 javax.servlet.Filter 接口,通过 @Component 注册为 Spring Bean)

import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.*;

/**
 * sql注入过滤器
 */
@Slf4j
@Component
@WebFilter(urlPatterns = "/*", filterName = "SQLInjection")
public class SqlInjectFilter implements Filter {

    private static final String SQL_REGX = ".*(\\b(select|update|and|or|delete|insert|trancate|char|into|substr|ascii|declare|exec|count|master|drop|execute)\\b).*";

    /**
     * springmvc启动时自动装配json处理类
     */
    @Resource
    private ObjectMapper objectMapper;

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) servletRequest;

        // 防止流读取一次后就没有了, 所以需要将流继续写出去
        MyRequestWrapper requestWrapper = new MyRequestWrapper(req);

        // 获取请求参数
        Map<String, Object> paramsMaps = new TreeMap<>();
        if ("POST".equals(req.getMethod().toUpperCase())) {
            String body = requestWrapper.getBody();
            paramsMaps = JSONObject.parseObject(body, TreeMap.class);
        } else {
            Map<String, String[]> parameterMap = requestWrapper.getParameterMap();
            Set<Map.Entry<String, String[]>> entries = parameterMap.entrySet();
            for (Map.Entry<String, String[]> next : entries) {
                paramsMaps.put(next.getKey(), next.getValue()[0]);
            }
        }

        // 校验SQL注入
        for (Object o : paramsMaps.entrySet()) {
            Map.Entry entry = (Map.Entry) o;
            Object value = entry.getValue();
            if (value != null) {
                boolean isValid = checkSqlInject(value.toString(), servletResponse);
                if (!isValid) {
                    return;
                }
            }
        }

        chain.doFilter(requestWrapper, servletResponse);
    }

    //获取request请求body中参数
    public static String getBodyString(BufferedReader br) {
        String inputLine;
        String str = "";
        try {
            while ((inputLine = br.readLine()) != null) {
                str += inputLine;
            }
            br.close();
        } catch (IOException e) {
            System.out.println("IOException: " + e);
        }
        return str;
    }

    /**
     * 检查SQL注入
     *
     * @param value           参数值
     * @param servletResponse 相应实例
     * @throws IOException      IO异常
     */
    private boolean checkSqlInject(String value, ServletResponse servletResponse) throws IOException {
        if (null != value && value.matches(SQL_REGX)) {
            log.error("您输入的参数有非法字符,请输入正确的参数");
            HttpServletResponse response = (HttpServletResponse) servletResponse;

            Map<String, String> rsp = new HashMap<>();
            rsp.put("code", HttpStatus.BAD_REQUEST.value() + "");
            rsp.put("message", "您输入的参数有非法字符,请输入正确的参数!");

            response.setStatus(HttpStatus.OK.value());
            response.setContentType("application/json;charset=UTF-8");
            response.getWriter().write(objectMapper.writeValueAsString(rsp));
            response.getWriter().flush();
            response.getWriter().close();
            return false;
        }
        return true;
    }

    @Override
    public void destroy() {
    }

}

SqlInjectFilter.java(写法二:继承 Spring 的 OncePerRequestFilter)

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.http.HttpStatus;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

/**
 * SQL-injection filter built on Spring's {@link OncePerRequestFilter}.
 *
 * <p>The request is wrapped in {@link MyRequestWrapper} so its body can be inspected here and
 * still be read downstream. The filter inspects:
 * <ul>
 *   <li>every value of every request parameter;</li>
 *   <li>the body: the scalar leaves of a JSON body, or the raw text of a non-JSON body.</li>
 * </ul>
 * If any value contains a blacklisted SQL keyword as a whole word, case-insensitively, a JSON
 * error payload is written and the rest of the chain is skipped.
 *
 * <p>NOTE: keyword blacklisting is a stopgap, not a defence; the real fix for SQL injection is
 * parameterized queries (PreparedStatement, MyBatis #{}).
 *
 * <p>NOTE(review): this class carries no {@code @Component}. Presumably it is registered as a
 * bean elsewhere (a {@code @Bean} or {@code FilterRegistrationBean}); confirm, since
 * {@code @Order}/{@code @ConditionalOnClass} only take effect on a bean definition.
 */
@Slf4j
@Order(Ordered.HIGHEST_PRECEDENCE + 8)
@ConditionalOnClass(WebMvcConfigurer.class)
public class SqlInjectFilter extends OncePerRequestFilter {

    /**
     * Blacklisted SQL keywords, matched as whole words anywhere in a value.
     *
     * <p>Matching is case-insensitive and uses {@code find()}, so "SELECT" and keywords that
     * follow a line break are caught. The pattern is compiled once rather than on every
     * {@code String.matches} call.
     */
    private static final java.util.regex.Pattern SQL_PATTERN = java.util.regex.Pattern.compile(
            "\\b(select|update|and|or|delete|insert|truncate|char|into|substr|ascii|declare|exec|count|master|drop|execute)\\b",
            java.util.regex.Pattern.CASE_INSENSITIVE);

    @Override
    protected void doFilterInternal(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, FilterChain filterChain) throws ServletException, IOException {
        // Reading the body consumes the stream, so cache it in a wrapper that replays it downstream.
        MyRequestWrapper requestWrapper = new MyRequestWrapper(httpServletRequest);

        for (String value : collectValues(requestWrapper)) {
            if (!checkSqlInject(value, httpServletResponse)) {
                // The rejection response has already been written.
                return;
            }
        }

        filterChain.doFilter(requestWrapper, httpServletResponse);
    }

    /**
     * Collects the values to inspect, regardless of HTTP method.
     *
     * <p>It collects two things:
     * <ul>
     *   <li>every value of every request parameter, not just the first one, so repeated
     *       parameters cannot hide a payload behind a harmless first value;</li>
     *   <li>the body, when it is not blank. A JSON body contributes its scalar leaves. A body
     *       that fastjson cannot parse is inspected as raw text instead of failing the request.</li>
     * </ul>
     *
     * <p>NOTE(review): fastjson honours "@type" on untrusted input; make sure autoType is
     * disabled (safeMode) for the fastjson version in use.
     */
    private static java.util.List<String> collectValues(MyRequestWrapper request) {
        java.util.List<String> values = new java.util.ArrayList<>();
        for (String[] paramValues : request.getParameterMap().values()) {
            if (paramValues == null) {
                continue;
            }
            for (String paramValue : paramValues) {
                if (paramValue != null) {
                    values.add(paramValue);
                }
            }
        }

        String body = request.getBody();
        if (body != null && !body.trim().isEmpty()) {
            Object parsed;
            try {
                parsed = JSONObject.parse(body);
            } catch (RuntimeException notJson) {
                // fastjson reports malformed input with a RuntimeException (JSONException).
                // Inspect the raw text rather than turning the request into a server error.
                parsed = body;
            }
            collectLeaves(parsed, values);
        }
        return values;
    }

    /**
     * Appends the scalar leaves of a parsed JSON value to {@code out}. Maps (objects) and
     * collections (arrays) are walked recursively; only values are collected, not keys.
     */
    private static void collectLeaves(Object node, java.util.List<String> out) {
        if (node instanceof Map) {
            for (Object child : ((Map<?, ?>) node).values()) {
                collectLeaves(child, out);
            }
        } else if (node instanceof java.util.Collection) {
            for (Object child : (java.util.Collection<?>) node) {
                collectLeaves(child, out);
            }
        } else if (node != null) {
            out.add(node.toString());
        }
    }

    /**
     * Checks one value for blacklisted SQL keywords. If one is found, the rejection response is
     * written and {@code false} is returned.
     *
     * <p>The rejection uses HTTP status 200 with a JSON body whose {@code code} is "400".
     * NOTE(review): presumably the front end reads {@code code} rather than the status line;
     * confirm before switching to a real 400.
     *
     * @param value           the parameter value; null is accepted as clean
     * @param servletResponse the response to write the rejection to
     * @return {@code true} if the value is clean, {@code false} if the request was rejected
     * @throws IOException if writing the response fails
     */
    private boolean checkSqlInject(String value, ServletResponse servletResponse) throws IOException {
        if (value == null || !SQL_PATTERN.matcher(value).find()) {
            return true;
        }
        log.error("您输入的参数有非法字符,请输入正确的参数");
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        response.setStatus(HttpStatus.OK.value());

        Map<String, String> rsp = new HashMap<>();
        rsp.put("code", String.valueOf(HttpStatus.BAD_REQUEST.value()));
        rsp.put("message", "您输入的参数有非法字符,请输入正确的参数!");

        response.setContentType("application/json;charset=UTF-8");
        java.io.PrintWriter writer = response.getWriter();
        writer.write(JSON.toJSONString(rsp));
        writer.flush();
        writer.close();
        return false;
    }

    /** Nothing to release. */
    @Override
    public void destroy() {
    }

}

2.2 请求参数封装类

请求封装类 MyRequestWrapper 在构造时读取并缓存请求体,之后通过 getInputStream()/getReader() 从内存中重新提供请求体。这样过滤器通过 getBody() 读取 POST 请求体之后,Controller 仍能再次读取。

import org.apache.commons.codec.Charsets;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.util.Enumeration;
import java.util.Map;

/**
 * Request wrapper that reads the request body once, caches the raw bytes, and replays them from
 * memory through {@link #getInputStream()} and {@link #getReader()}. A filter can therefore
 * inspect the body via {@link #getBody()} while downstream code can still read it.
 *
 * <p>Form bodies (application/x-www-form-urlencoded and multipart content types) are NOT
 * cached. The container parses those into request parameters/parts itself, and it cannot do so
 * once the stream has been drained. For such requests {@link #getBody()} returns "" and the
 * stream/reader methods delegate to the wrapped request; inspect them through the parameter API.
 *
 * <p>NOTE(review): every other body is held fully in memory; consider a size limit.
 */
public class MyRequestWrapper extends HttpServletRequestWrapper {

    /**
     * Raw body bytes, replayed verbatim so binary or non-UTF-8 payloads reach downstream
     * unchanged. Null when the body is left to the container.
     */
    private final byte[] bodyBytes;

    /** The cached body decoded as UTF-8; "" when nothing was cached. */
    private final String body;

    /**
     * Wraps {@code request}, draining and caching its body unless it is a form body.
     *
     * @param request the request to wrap
     * @throws IOException if reading the body fails
     */
    public MyRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        if (isFormContent(request.getContentType())) {
            this.bodyBytes = null;
        } else {
            this.bodyBytes = readFully(request.getInputStream());
        }
        this.body = bodyBytes == null
                ? ""
                : new String(bodyBytes, java.nio.charset.StandardCharsets.UTF_8);
    }

    /**
     * Returns whether {@code contentType} denotes a form body that the container parses itself.
     */
    private static boolean isFormContent(String contentType) {
        if (contentType == null) {
            return false;
        }
        String type = contentType.trim().toLowerCase(java.util.Locale.ROOT);
        return type.startsWith("application/x-www-form-urlencoded") || type.startsWith("multipart/");
    }

    /** Drains and closes {@code in}; returns an empty array when {@code in} is null. */
    private static byte[] readFully(InputStream in) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        if (in != null) {
            try (InputStream source = in) {
                byte[] chunk = new byte[4096];
                int count;
                while ((count = source.read(chunk)) != -1) {
                    buffer.write(chunk, 0, count);
                }
            }
        }
        return buffer.toByteArray();
    }

    /**
     * Returns a fresh stream over the cached body. If the body was not cached, the wrapped
     * request's stream is returned instead.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (bodyBytes == null) {
            return super.getInputStream();
        }
        final ByteArrayInputStream source = new ByteArrayInputStream(bodyBytes);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reads are not supported by this in-memory stream;
                // the listener is ignored.
            }

            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                // Bulk read instead of InputStream's default byte-at-a-time loop.
                return source.read(b, off, len);
            }
        };
    }

    /**
     * Returns a UTF-8 reader over the cached body. If the body was not cached, the wrapped
     * request's reader is returned instead.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        if (bodyBytes == null) {
            return super.getReader();
        }
        return new BufferedReader(new InputStreamReader(this.getInputStream(), java.nio.charset.StandardCharsets.UTF_8));
    }

    /**
     * Returns the cached body decoded as UTF-8, or "" for form bodies and empty requests.
     */
    public String getBody() {
        return this.body;
    }
}
0

评论区