0時に寝て7時に起きた。

spring boot における簡易的なミドルウェアの実装

bff (backend for frontend) の web api サーバーから別の web api サーバーのエンドポイントを呼び出す処理を書いていてただプロキシする処理のためだけに controller を書くのも面倒だなと気付いた。controller よりも低いレイヤで raw request を扱うにはどうしたらいいかを調べてみたら spring の Filter を使うのが最も簡単そうにみえた。その spring の Filter の1つである OncePerRequestFilter を使って簡易的なプロキシの処理を実装してみた。OncePerRequestFilter はリクエストに対して1度だけ呼ばれることが保証された Filter になる。これらのドキュメントがどこにあるかわからなかったので baeldung のチュートリアル What Is OncePerRequestFilter? をみた方が早いかもしれない。

Filter が複数あるときは実行順序を指定したいときは Order というアノテーションで制御できる。

@Component
@Order(30)
public class MyFilter extends OncePerRequestFilter {
  ...
}

OncePerRequestFilter の doFilterInternal メソッドをオーバーライドすれば任意のミドルウェアっぽい処理を実装できる。

    @Override
    protected void doFilterInternal(
        HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
        throws ServletException, IOException {
        // ここに任意の前処理を実装する
        filterChain.doFilter(request, response);
        // ここに任意の後処理を実装する
    }

あとは HttpServletRequest と HttpServletResponse を直接操作してリクエストから必要な情報を取り出して、クライアントでリクエストして返ってきたレスポンスをそのまま返してあげればよい。簡単に実装したものが次になる。クライアントには 以前に少し調べたことを書いた WebClient を使っている。WebClient じゃなければ、もう少しシンプルに書ける気もする。

@Component
@Order(30)
public class RequestForwardFilter extends OncePerRequestFilter {
    @Autowired
    private List<RequestForwardMatcher> matchers;

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        var matcher = this.getMatcher(request);
        if (matcher.isEmpty()) {
            filterChain.doFilter(request, response);
            return;
        }
        writeResponse(response, this.forwardRequest(request, matcher.get().getWebClient()));
    }

    private Optional<byte[]> forwardRequest(HttpServletRequest request, WebClient webClient) throws IOException {
        var method = HttpMethod.resolve(request.getMethod().toUpperCase());
        var originalUri = request.getRequestURI();
        var uri = this.getUri(originalUri, request.getQueryString());
        var spec = this.getSpec(method, webClient, uri, request.getReader());
        return spec.retrieve().bodyToMono(byte[].class).blockOptional();
    }

    private void writeResponse(HttpServletResponse response, Optional<byte[]> bytes) throws IOException {
        if (bytes.isEmpty()) {
            return;
        }
        var stream = response.getOutputStream();
        stream.write(bytes.get());
        stream.flush();
        stream.close();
    }

    private byte[] getRequestBody(BufferedReader reader) {
        return reader.lines().collect(Collectors.joining()).getBytes();
    }

    private String getUri(String originalUri, String queryString) {
        var path = originalUri.replace("/api", "");
        if (StringUtils.hasLength(queryString)) {
            return String.format("%s?%s", path, queryString);
        }
        return path;
    }

    private WebClient.RequestHeadersSpec<?> getSpec(
            HttpMethod method, WebClient webClient, String uri, BufferedReader reader) {
        switch (method) {
            case GET:
                return webClient.get().uri(uri);
            case POST:
                return webClient.post()
                        .uri(uri)
                        .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
                        .bodyValue(this.getRequestBody(reader));
            case PUT:
                return webClient.put()
                        .uri(uri)
                        .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
                        .bodyValue(this.getRequestBody(reader));
            case DELETE:
                return webClient.delete().uri(uri);
            default:
                throw new RuntimeException(String.format("Unsupported HTTP method: %s", method));
        }
    }

    private Optional<RequestForwardMatcher> getMatcher(HttpServletRequest request) {
        for (var matcher : this.matchers) {
            if (matcher.match(request)) {
                return Optional.of(matcher);
            }
        }
        return Optional.empty();
    }
}

matcher のサンプル実装。

public class RequestForwardMatcher {
    private static final String ASTERISK = "*";

    private final Pattern pattern;
    private final List<String> methods;
    private final WebClient webClient;

    public RequestForwardMatcher(String pattern, List<String> methods, WebClient webClient) {
        this.pattern = Pattern.compile(pattern);
        this.methods = methods;
        this.webClient = webClient;
    }

    private boolean matchMethod(String method) {
        if (this.methods.contains(ASTERISK)) {
            return true;
        }
        return this.methods.contains(method);
    }

    public boolean match(HttpServletRequest request) {
        if (!matchMethod(request.getMethod())) {
            return false;
        }
        return this.pattern.matcher(request.getRequestURI()).matches();
    }

    public WebClient getWebClient() {
        return this.webClient;
    }
}