package com.sinosoftgz.starter.jwt.filter;

import com.alibaba.fastjson.JSON;
import com.auth0.jwt.JWT;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.sinosoftgz.global.common.response.BaseResponse;
import com.sinosoftgz.global.common.response.enums.CommonResponseCodeEnum;
import com.sinosoftgz.starter.jwt.properties.JwtProperties;
import com.sinosoftgz.starter.jwt.utils.JwtUtils;
import com.sinosoftgz.starter.utils.lang.Lang;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.PathMatcher;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;

/**
 * Servlet filter that rejects requests that do not carry an accepted JWT, except for URLs
 * configured as excluded.
 *
 * <p>The token is read from the request header named by {@code jwtProperties.getHeaderKeyOfToken()},
 * falling back to a request parameter with the same name. Excluded URLs come from
 * {@code jwtProperties.getUrlExcludes()} and are interpreted relative to the application (context
 * path not included). Each entry is matched against the container-normalized path within the
 * application, either as an Ant-style pattern (e.g. {@code /public/**}) or, for plain paths, as an
 * exact path or a parent path on a segment boundary ({@code /login} covers {@code /login} and
 * {@code /login/x}, but not {@code /loginAdmin} or {@code /x/login}).
 *
 * <p>Created by Roney on 2021/1/4 20:58.
 */
public class JwtAuthenticationFilter implements Filter {

    private final JwtUtils jwtUtils;

    private final JwtProperties jwtProperties;

    final PathMatcher matcher = new AntPathMatcher();

    public JwtAuthenticationFilter(JwtUtils jwtUtils, JwtProperties jwtProperties) {
        this.jwtUtils = jwtUtils;
        this.jwtProperties = jwtProperties;
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // Nothing to initialise: all collaborators are supplied through the constructor.
    }

    /**
     * Lets excluded URLs through unconditionally. Otherwise it requires a non-empty token that
     * passes validation, and answers with the forbidden JSON body when the token is missing or
     * rejected.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;
        // URLs that do not require login are passed straight through.
        if (isExcludedUrl(request)) {
            filterChain.doFilter(request, response);
            return;
        }
        String token = getJwtToken(request);
        if (StringUtils.isEmpty(token) || !isTokenAccepted(token)) {
            forbidden(response);
            return;
        }
        filterChain.doFilter(request, response);
    }

    /**
     * Validates the token with {@code jwtUtils.expire(token)}. A {@code false} result rejects the
     * token, and so does any runtime exception thrown while checking it (e.g. a token that cannot
     * be decoded).
     */
    private boolean isTokenAccepted(String token) {
        try {
            return jwtUtils.expire(token);
        } catch (RuntimeException ignored) {
            // Fail closed: a token that cannot even be checked must be refused, not turned into a 500.
            return false;
        }
    }

    /**
     * Writes the JSON body produced by {@code BaseResponse.forbidden(...)} as UTF-8 and closes the writer.
     *
     * <p>NOTE(review): the HTTP status code is not set here, so it stays at whatever the response
     * already has (the servlet default is 200). Confirm that clients rely on the JSON body before
     * adding {@code response.setStatus(HttpServletResponse.SC_FORBIDDEN)}.
     */
    private void forbidden(HttpServletResponse response) throws IOException {
        response.setContentType("application/json");
        response.setCharacterEncoding("UTF-8");
        PrintWriter writer = response.getWriter();
        writer.write(JSON.toJSONString(BaseResponse.forbidden(CommonResponseCodeEnum.FORBIDDEN.getResultMsg())));
        writer.flush();
        writer.close();
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }

    /**
     * Returns the token from the configured header, or from the request parameter with the same
     * name when the header is absent or empty; may return {@code null}.
     */
    private String getJwtToken(HttpServletRequest request) {
        String token = request.getHeader(jwtProperties.getHeaderKeyOfToken());
        if (StringUtils.isEmpty(token)) {
            token = request.getParameter(jwtProperties.getHeaderKeyOfToken());
        }
        return token;
    }

    /**
     * Tells whether the request targets a URL that needs no authentication.
     *
     * <p>Matching is done against the servlet path plus path info rather than the raw
     * {@code getRequestURI()}. The raw URI is neither decoded nor normalized, so tricks such as
     * {@code /public/..;/admin} or {@code /admin;/login} would otherwise match an exclude while the
     * request is routed to a protected handler.
     *
     * @param request the current request
     * @return {@code true} when an exclude entry covers the request path
     */
    private boolean isExcludedUrl(HttpServletRequest request) {
        if (jwtProperties.getUrlExcludes() == null) {
            return false;
        }
        String path = pathWithinApplication(request);
        for (String excludeUrl : jwtProperties.getUrlExcludes()) {
            if (StringUtils.isEmpty(excludeUrl)) {
                // An empty entry would otherwise match every path and switch authentication off.
                continue;
            }
            String pattern = excludeUrl.startsWith("/") ? excludeUrl : "/" + excludeUrl;
            if (matcher.match(pattern, path)) {
                return true;
            }
            // Plain (non-wildcard) entries also cover their sub-paths, but only on a '/' boundary.
            if (!matcher.isPattern(pattern) && isSamePathOrSubPath(path, pattern)) {
                return true;
            }
        }
        return false;
    }

    /** Builds the container-normalized path within the application (context path excluded). */
    private static String pathWithinApplication(HttpServletRequest request) {
        String servletPath = request.getServletPath();
        String pathInfo = request.getPathInfo();
        String path = (servletPath == null ? "" : servletPath) + (pathInfo == null ? "" : pathInfo);
        return path.isEmpty() ? "/" : path;
    }

    /** Returns whether {@code path} equals {@code prefix} or lies beneath it on a path-segment boundary. */
    private static boolean isSamePathOrSubPath(String path, String prefix) {
        if (path.equals(prefix)) {
            return true;
        }
        String directoryPrefix = prefix.endsWith("/") ? prefix : prefix + "/";
        return path.startsWith(directoryPrefix);
    }
}
