编程开源技术交流,分享技术与知识

网站首页 > 开源技术 正文

自己动手写一个服务网关(服务器做网关)

wxchong 2024-08-06 03:14:42 开源技术 23 ℃ 0 评论

什么是网关?为什么需要使用网关?


如图所示,在不使用网关的情况下,我们的服务是直接暴露给服务调用方。当调用方增多,势必需要添加定制化访问权限、校验等逻辑。当添加API网关后,再第三方调用端和服务提供方之间就创建了一面墙,这面墙直接与调用方通信进行权限控制。

本文所实现的网关源码抄袭了---Oh,不对,是借鉴。借鉴了Zuul网关的源码,提炼出其核心思路,实现了一套简单的网关源码,博主将其改名为Eatuul。

题外话

本文是业内能搜到的第一篇自己动手实现网关的文章。博主写的手把手系列的文章,目的是在以最简单的方式,揭露出中间件的核心原理,让读者能够迅速了解实现的核心。需要说明的是,这不是源码分析系列的文章,因此写出来的代码,省去了一些复杂的内容,毕竟大家能理解到该中间件的核心原理即可。如果想看源码分析系列的,请关注博主,后期会将spring、spring boot、dubbo、mybatis等开源框架一一揭示。

正文设计思路

先大致说一下,就是定义一个Servlet接收请求。然后经过preFilter(封装请求参数),routeFilter(转发请求),postFilter(输出内容)。三个过滤器之间,共享request、response以及其他的一些全局变量。如下图所示


# 和真正的Zuul的区别?主要区别有如下几点

(1)Zuul中在异常处理模块,有一个ErrorFilter来处理,博主在实现的时候偷懒了,略去。

(2)Zuul中PreFilters,RoutingFilters,PostFilters默认都实现了一组,具体如下表所示


博主总不可能每一个都给你们实现一遍吧。所以偷懒了,每种只实现一个。但是调用顺序还是不变,按照PreFilters->RoutingFilters->PostFilters的顺序调用

(3)在routeFilters确实有转发请求的Filter,然而博主偷天换日了,改用RestTemplate实现.

代码结构

大家去spring官网上搭建一套springboot的项目,博主就不展示pom的代码了。直接将项目结构展示一下,如下图所示


# EatuulServlet.java。这个是网关的入口,逻辑也十分简单,分为三步

(1)将request,response放入threadlocal中

(2)执行三组过滤器

(3)清除threadlocal中的的环境变量

源码如下

package com.rjzheng.eatuul.http;
import java.io.IOException;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@WebServlet(name = "eatuul", urlPatterns = "/*")
public class EatuulServlet extends HttpServlet {
 private EatRunner eatRunner = new EatRunner();
 @Override
 public void service(HttpServletRequest req, HttpServletResponse resp)
 throws ServletException, IOException {
 //将request,和response放入上下文对象中
 eatRunner.init(req, resp);
 try {
 //执行前置过滤
 eatRunner.preRoute();
 //执行过滤
 eatRunner.route();
 //执行后置过滤
 eatRunner.postRoute();
 } catch (Throwable e) {
 RequestContext.getCurrentContext().getResponse()
 .sendError(HttpServletResponse.SC_NOT_FOUND, e.getMessage());
 } finally {
 //清除变量
 RequestContext.getCurrentContext().unset();
 }
 }
}

EatuulRunner.java。

这个是具体的执行器。需要说明一下,在Zuul中,ZuulRunner在获取具体有哪些过滤器的时候,有一个FileLoader可以动态读取配置加载。博主在实现我们自己的EatuulRunner时候,略去动态读取的过程,直接静态写死。

源码如下

package com.rjzheng.eatuul.http;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.filter.post.SendResponseFilter;
import com.rjzheng.eatuul.filter.pre.RequestWrapperFilter;
import com.rjzheng.eatuul.filter.route.RoutingFilter;
public class EatRunner {
 //静态写死过滤器
 private ConcurrentHashMap<String, List<EatuulFilter>> hashFiltersByType = new ConcurrentHashMap<String, List<EatuulFilter>>(){{ 
 put("pre",new ArrayList<EatuulFilter>(){{
 add(new RequestWrapperFilter());
 }});
 put("route",new ArrayList<EatuulFilter>(){{
 add(new RoutingFilter());
 }});
 put("post",new ArrayList<EatuulFilter>(){{
 add(new SendResponseFilter());
 }});
 }};
 
 public void init(HttpServletRequest req, HttpServletResponse resp) {
 RequestContext ctx = RequestContext.getCurrentContext();
 ctx.setRequest(req);
 ctx.setResponse(resp);
 }
 public void preRoute() throws Throwable {
 runFilters("pre"); 
 }
 public void route() throws Throwable{
 runFilters("route"); 
 }
 public void postRoute() throws Throwable{
 runFilters("post");
 }
 
 public void runFilters(String sType) throws Throwable {
 List<EatuulFilter> list = this.hashFiltersByType.get(sType);
 if (list != null) {
 for (int i = 0; i < list.size(); i++) {
 EatuulFilter zuulFilter = list.get(i);
 zuulFilter.run();
 }
 }
 }
}

EatuulFilter.java。接下来就是一系列Filter的代码了,先上父类EatuulFilter的源码

package com.rjzheng.eatuul.filter;
public abstract class EatuulFilter {
 abstract public String filterType();
 abstract public int filterOrder();
 abstract public void run();
}

RequestWrapperFilter.java。这个是PreFilter,前置执行过滤器,负责封装请求。步骤如下所示

(1)封装请求头

(2)封装请求体

(3)构造出RestTemplate能识别的RequestEntity

(4)将RequestEntity放入全局threadlocal之中

代码如下所示

package com.rjzheng.eatuul.filter.pre;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.RequestEntity;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StreamUtils;
import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.http.RequestContext;
public class RequestWrapperFilter extends EatuulFilter{
 @Override
 public String filterType() {
 // TODO Auto-generated method stub
 return "pre";
 }
 @Override
 public int filterOrder() {
 // TODO Auto-generated method stub
 return -1;
 }
 @Override
 public void run() {
 String rootURL = "http://localhost:9090";
 RequestContext ctx =RequestContext.getCurrentContext();
 HttpServletRequest servletRequest = ctx.getRequest();
 String targetURL = rootURL + servletRequest.getRequestURI();
 RequestEntity<byte[]> requestEntity = null;
 try {
 requestEntity = createRequestEntity(servletRequest, targetURL);
 } catch (Exception e) {
 e.printStackTrace();
 }
 //4、将requestEntity放入全局threadlocal之中
 ctx.setRequestEntity(requestEntity);
 }
 
 private RequestEntity createRequestEntity(HttpServletRequest request,String url) throws URISyntaxException, IOException {
 String method = request.getMethod();
 HttpMethod httpMethod = HttpMethod.resolve(method);
 //1、封装请求头
 MultiValueMap<String, String> headers =createRequestHeaders(request);
 //2、封装请求体
 byte[] body = createRequestBody(request);
 //3、构造出RestTemplate能识别的RequestEntity
 RequestEntity requestEntity = new RequestEntity<byte[]>(body,headers,httpMethod, new URI(url));
 return requestEntity;
 }
 
 private byte[] createRequestBody(HttpServletRequest request) throws IOException {
 InputStream inputStream = request.getInputStream();
 return StreamUtils.copyToByteArray(inputStream);
 }
 private MultiValueMap<String, String> createRequestHeaders(HttpServletRequest request) {
 HttpHeaders headers = new HttpHeaders();
 List<String> headerNames = Collections.list(request.getHeaderNames());
 for(String headerName:headerNames) {
 List<String> headerValues = Collections.list(request.getHeaders(headerName));
 for(String headerValue:headerValues) {
 headers.add(headerName, headerValue);
 }
 }
 return headers;
 }
}

RoutingFilter.java。这个是routeFilter,这里我偷懒了,直接做转发请求,并且将返回值ResponseEntity放入全局threadlocal中

package com.rjzheng.eatuul.filter.route;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.http.RequestContext;
public class RoutingFilter extends EatuulFilter{
 @Override
 public String filterType() {
 // TODO Auto-generated method stub
 return "route";
 }
 @Override
 public int filterOrder() {
 // TODO Auto-generated method stub
 return 0;
 }
 
 @Override
 public void run(){
 RequestContext ctx = RequestContext.getCurrentContext();
 RequestEntity requestEntity = ctx.getRequestEntity();
 RestTemplate restTemplate = new RestTemplate();
 ResponseEntity responseEntity = restTemplate.exchange(requestEntity,byte[].class);
 ctx.setResponseEntity(responseEntity);
 }
 
}

SendResponseFilter.java。

这个是postFilters,将ResponseEntity输出即可

package com.rjzheng.eatuul.filter.post;
import java.util.List;
import java.util.Map;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseEntity;
import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.http.RequestContext;
public class SendResponseFilter extends EatuulFilter{
 @Override
 public String filterType() {
 return "post";
 }
 @Override
 public int filterOrder() {
 return 1000;
 }
 @Override
 public void run() {
 try {
 addResponseHeaders();
 writeResponse();
 } catch (Exception e) {
 e.printStackTrace();
 }
 }
 private void addResponseHeaders() {
 RequestContext ctx = RequestContext.getCurrentContext();
 HttpServletResponse servletResponse = ctx.getResponse();
 ResponseEntity responseEntity = ctx.getResponseEntity();
 HttpHeaders httpHeaders = responseEntity.getHeaders();
 for(Map.Entry<String, List<String>> entry:httpHeaders.entrySet()) {
 String headerName = entry.getKey();
 List<String> headerValues = entry.getValue();
 for(String headerValue:headerValues) {
 servletResponse.addHeader(headerName, headerValue);
 }
 }
 }
 private void writeResponse()throws Exception {
 RequestContext ctx = RequestContext.getCurrentContext();
 HttpServletResponse servletResponse = ctx.getResponse();
 if (servletResponse.getCharacterEncoding() == null) { // only set if not set
 servletResponse.setCharacterEncoding("UTF-8");
 }
 ResponseEntity responseEntity = ctx.getResponseEntity();
 if(responseEntity.hasBody()) {
 byte[] body = (byte[]) responseEntity.getBody();
 ServletOutputStream outputStream = servletResponse.getOutputStream();
 outputStream.write(body);
 outputStream.flush();
 }
 }
}

RequestContext.java。

最后是一直在说的全局threadlocal变量

package com.rjzheng.eatuul.http;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
public class RequestContext extends ConcurrentHashMap<String, Object> {
 protected static Class<? extends RequestContext> contextClass = RequestContext.class;
 protected static final ThreadLocal<? extends RequestContext> threadLocal = new ThreadLocal<RequestContext>() {
 @Override
 protected RequestContext initialValue() {
 try {
 return contextClass.newInstance();
 } catch (Throwable e) {
 throw new RuntimeException(e);
 }
 }
 };
 public static RequestContext getCurrentContext() {
 RequestContext context = threadLocal.get();
 return context;
 }
 public HttpServletRequest getRequest() {
 return (HttpServletRequest) get("request");
 }
 public void setRequest(HttpServletRequest request) {
 put("request", request);
 }
 public HttpServletResponse getResponse() {
 return (HttpServletResponse) get("response");
 }
 public void setResponse(HttpServletResponse response) {
 set("response", response);
 }
 
 public void setRequestEntity(RequestEntity requestEntity){
 set("requestEntity",requestEntity);
 }
 
 public RequestEntity getRequestEntity() {
 return (RequestEntity) get("requestEntity");
 }
 
 public void setResponseEntity(ResponseEntity responseEntity){
 set("responseEntity",responseEntity);
 }
 
 public ResponseEntity getResponseEntity() {
 return (ResponseEntity) get("responseEntity");
 }
 
 public void set(String key, Object value) {
 if (value != null)
 put(key, value);
 else
 remove(key);
 }
 public void unset() {
 threadLocal.remove();
 }
}

# 如何测试?

自己另外起一个server端口为9090如下所示

package com.rjzheng.eatservice;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.boot.web.servlet.ServletComponentScan;
import com.rjzheng.eatservice.controller.IndexController;
@SpringBootApplication
@ServletComponentScan(basePackageClasses = IndexController.class)
public class Application {
 public static void main(String[] args) {
 new SpringApplicationBuilder(Application.class).properties("server.port=9090").run(args);
 }
}

再来一个controller

package com.rjzheng.eatservice.controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
public class IndexController {
 
 @RequestMapping("/index")
 public String index() {
 return "hello!world";
 }
}

然后,你就发现可以从localhost:8080/index进行跳转访问了

结论

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表