大廚小鮮——基於Netty自己動手實現RPC框架

老錢發表於2018-04-15

今天我們要來做一道小菜,這道菜就是RPC通訊框架。它使用netty作為原料,fastjson序列化工具作為調料,來實現一個極簡的多執行緒RPC服務框架。

我們暫且命名該RPC框架為rpckids。

大廚小鮮——基於Netty自己動手實現RPC框架

食用指南

在告訴讀者完整的製作菜譜之前,我們先來試試這個小菜怎麼個吃法,好不好吃,是不是吃起來很方便。如果讀者覺得很難吃,那後面的菜譜就沒有多大意義了,何必花心思去學習製作一門誰也不愛吃的大爛菜呢?

例子中我會使用rpckids提供的遠端RPC服務,用於計算斐波那契數和指數,客戶端通過rpckids提供的RPC客戶端向遠端服務傳送引數,並接受返回結果,然後呈現出來。你可以使用rpckids定製任意的業務rpc服務。

大廚小鮮——基於Netty自己動手實現RPC框架

斐波那契數輸入輸出比較簡單,一個Integer,一個Long。 指數輸入有兩個值,輸出除了計算結果外還包含計算耗時,以納秒計算。之所以包含耗時,只是為了呈現一個完整的自定義的輸入和輸出類。

指數服務自定義輸入輸出類

// 指數RPC的輸入
public class ExpRequest {
	private int base;
	private int exp;
    
    // constructor & getter & setter
}

// 指數RPC的輸出
public class ExpResponse {

	private long value;
	private long costInNanos;

	// constructor & getter & setter
}
複製程式碼

斐波那契和指數計算處理

public class FibRequestHandler implements IMessageHandler<Integer> {

	private List<Long> fibs = new ArrayList<>();

	{
		fibs.add(1L); // fib(0) = 1
		fibs.add(1L); // fib(1) = 1
	}

	@Override
	public void handle(ChannelHandlerContext ctx, String requestId, Integer n) {
		for (int i = fibs.size(); i < n + 1; i++) {
			long value = fibs.get(i - 2) + fibs.get(i - 1);
			fibs.add(value);
		}
		// 輸出響應
		ctx.writeAndFlush(new MessageOutput(requestId, "fib_res", fibs.get(n)));
	}

}

public class ExpRequestHandler implements IMessageHandler<ExpRequest> {

	@Override
	public void handle(ChannelHandlerContext ctx, String requestId, ExpRequest message) {
		int base = message.getBase();
		int exp = message.getExp();
		long start = System.nanoTime();
		long res = 1;
		for (int i = 0; i < exp; i++) {
			res *= base;
		}
		long cost = System.nanoTime() - start;
		// 輸出響應
		ctx.writeAndFlush(new MessageOutput(requestId, "exp_res", new ExpResponse(res, cost)));
	}

}
複製程式碼

構建RPC伺服器

RPC服務類要監聽指定IP埠,設定io執行緒數和業務計算執行緒數,然後註冊斐波那契服務輸入類和指數服務輸入類,還有相應的計算處理器。

public class DemoServer {

	public static void main(String[] args) {
		RPCServer server = new RPCServer("localhost", 8888, 2, 16);
		server.service("fib", Integer.class, new FibRequestHandler())
			  .service("exp", ExpRequest.class, new ExpRequestHandler());
		server.start();
	}

}
複製程式碼

構建RPC客戶端

RPC客戶端要連結遠端IP埠,並註冊服務輸出類(RPC響應類),然後分別呼叫20次斐波那契服務和指數服務,輸出結果

public class DemoClient {

	private RPCClient client;

	public DemoClient(RPCClient client) {
		this.client = client;
		// 註冊服務返回型別
		this.client.rpc("fib_res", Long.class).rpc("exp_res", ExpResponse.class);
	}

	public long fib(int n) {
		return (Long) client.send("fib", n);
	}

	public ExpResponse exp(int base, int exp) {
		return (ExpResponse) client.send("exp", new ExpRequest(base, exp));
	}

	public static void main(String[] args) {
		RPCClient client = new RPCClient("localhost", 8888);
		DemoClient demo = new DemoClient(client);
		for (int i = 0; i < 20; i++) {
			System.out.printf("fib(%d) = %d\n", i, demo.fib(i));
		}
		for (int i = 0; i < 20; i++) {
			ExpResponse res = demo.exp(2, i);
			System.out.printf("exp2(%d) = %d cost=%dns\n", i, res.getValue(), res.getCostInNanos());
		}
	}

}
複製程式碼

執行

先執行伺服器,伺服器輸出如下,從日誌中可以看到客戶端連結過來了,然後傳送了一系列訊息,最後關閉連結走了。

server started @ localhost:8888
connection comes
read a message
read a message
...
connection leaves
複製程式碼

再執行客戶端,可以看到一些列的計算結果都成功完成了輸出。

fib(0) = 1
fib(1) = 1
fib(2) = 2
fib(3) = 3
fib(4) = 5
...
exp2(0) = 1 cost=559ns
exp2(1) = 2 cost=495ns
exp2(2) = 4 cost=524ns
exp2(3) = 8 cost=640ns
exp2(4) = 16 cost=711ns
...

複製程式碼

牢騷

本以為是小菜一碟,但是編寫完整的程式碼和文章卻將近花費了一天的時間,深感寫碼要比做菜耗時太多了。因為只是為了教學目的,所以在實現細節上還有好多沒有仔細去雕琢的地方。如果是要做一個開源專案,力求非常完美的話。至少還要考慮一下幾點。

  1. 客戶端連線池
  2. 多服務程式負載均衡
  3. 日誌輸出
  4. 引數校驗,異常處理
  5. 客戶端流量攻擊
  6. 伺服器壓力極限

如果要參考grpc的話,還得實現流式響應處理。如果還要為了節省網路流量的話,又需要在協議上下功夫。這一大堆的問題還是拋給讀者自己思考去吧。

關注公眾號「碼洞」,傳送「RPC」即可獲取以上完整菜譜的GitHub開原始碼連結。讀者有什麼不明白的地方,洞主也會一一解答。

下面我們接著講RPC伺服器和客戶端精細的製作過程

伺服器菜譜

定義訊息輸入輸出格式,訊息型別、訊息唯一ID和訊息的json序列化字串內容。訊息唯一ID是用來客戶端驗證伺服器請求和響應是否匹配。

public class MessageInput {
	private String type;
	private String requestId;
	private String payload;

	public MessageInput(String type, String requestId, String payload) {
		this.type = type;
		this.requestId = requestId;
		this.payload = payload;
	}

	public String getType() {
		return type;
	}

	public String getRequestId() {
		return requestId;
	}
    
    // 因為我們想直接拿到物件,所以要提供物件的型別引數
	public <T> T getPayload(Class<T> clazz) {
		if (payload == null) {
			return null;
		}
		return JSON.parseObject(payload, clazz);
	}

}

public class MessageOutput {

	private String requestId;
	private String type;
	private Object payload;

	public MessageOutput(String requestId, String type, Object payload) {
		this.requestId = requestId;
		this.type = type;
		this.payload = payload;
	}

	public String getType() {
		return this.type;
	}

	public String getRequestId() {
		return requestId;
	}

	public Object getPayload() {
		return payload;
	}

}
複製程式碼

訊息解碼器,使用Netty的ReplayingDecoder實現。簡單起見,這裡沒有使用checkpoint去優化效能了,感興趣的話讀者可以參考一下我之前在公眾號裡發表的相關文章,將checkpoint相關的邏輯自己新增進去。

public class MessageDecoder extends ReplayingDecoder<MessageInput> {

	@Override
	protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
		String requestId = readStr(in);
		String type = readStr(in);
		String content = readStr(in);
		out.add(new MessageInput(type, requestId, content));
	}

	private String readStr(ByteBuf in) {
		// 字串先長度後位元組陣列,統一UTF8編碼
		int len = in.readInt();
		if (len < 0 || len > (1 << 20)) {
			throw new DecoderException("string too long len=" + len);
		}
		byte[] bytes = new byte[len];
		in.readBytes(bytes);
		return new String(bytes, Charsets.UTF8);
	}

}
複製程式碼

訊息處理器介面,每個自定義服務必須實現handle方法

public interface IMessageHandler<T> {

	void handle(ChannelHandlerContext ctx, String requestId, T message);

}

// 找不到型別的訊息統一使用預設處理器處理
public class DefaultHandler implements IMessageHandler<MessageInput> {

	@Override
	public void handle(ChannelHandlerContext ctx, String requesetId, MessageInput input) {
		System.out.println("unrecognized message type=" + input.getType() + " comes");
	}

}
複製程式碼

訊息型別註冊中心和訊息處理器註冊中心,都是用靜態欄位和方法,其實也是為了圖方便,寫成非靜態的可能會優雅一些。

public class MessageRegistry {
	private static Map<String, Class<?>> clazzes = new HashMap<>();

	public static void register(String type, Class<?> clazz) {
		clazzes.put(type, clazz);
	}

	public static Class<?> get(String type) {
		return clazzes.get(type);
	}
}

public class MessageHandlers {

	private static Map<String, IMessageHandler<?>> handlers = new HashMap<>();
	public static DefaultHandler defaultHandler = new DefaultHandler();

	public static void register(String type, IMessageHandler<?> handler) {
		handlers.put(type, handler);
	}

	public static IMessageHandler<?> get(String type) {
		IMessageHandler<?> handler = handlers.get(type);
		return handler;
	}

}
複製程式碼

響應訊息的編碼器比較簡單

@Sharable
public class MessageEncoder extends MessageToMessageEncoder<MessageOutput> {

	@Override
	protected void encode(ChannelHandlerContext ctx, MessageOutput msg, List<Object> out) throws Exception {
		ByteBuf buf = PooledByteBufAllocator.DEFAULT.directBuffer();
		writeStr(buf, msg.getRequestId());
		writeStr(buf, msg.getType());
		writeStr(buf, JSON.toJSONString(msg.getPayload()));
		out.add(buf);
	}

	private void writeStr(ByteBuf buf, String s) {
		buf.writeInt(s.length());
		buf.writeBytes(s.getBytes(Charsets.UTF8));
	}

}
複製程式碼

好,接下來進入關鍵環節,將上面的小模小塊湊在一起,構建一個完整的RPC伺服器框架,這裡就需要讀者有必須的Netty基礎知識了,需要編寫Netty的事件回撥類和服務構建類。

@Sharable
public class MessageCollector extends ChannelInboundHandlerAdapter {
    // 業務執行緒池
	private ThreadPoolExecutor executor;

	public MessageCollector(int workerThreads) {
		// 業務佇列最大1000,避免堆積
		// 如果子執行緒處理不過來,io執行緒也會加入處理業務邏輯(callerRunsPolicy)
		BlockingQueue<Runnable> queue = new ArrayBlockingQueue<>(1000);
		// 給業務執行緒命名
		ThreadFactory factory = new ThreadFactory() {

			AtomicInteger seq = new AtomicInteger();

			@Override
			public Thread newThread(Runnable r) {
				Thread t = new Thread(r);
				t.setName("rpc-" + seq.getAndIncrement());
				return t;
			}

		};
		// 閒置時間超過30秒的執行緒自動銷燬
		this.executor = new ThreadPoolExecutor(1, workerThreads, 30, TimeUnit.SECONDS, queue, factory,
				new CallerRunsPolicy());
	}

	public void closeGracefully() {
		// 優雅一點關閉,先通知,再等待,最後強制關閉
		this.executor.shutdown();
		try {
			this.executor.awaitTermination(10, TimeUnit.SECONDS);
		} catch (InterruptedException e) {
		}
		this.executor.shutdownNow();
	}

	@Override
	public void channelActive(ChannelHandlerContext ctx) throws Exception {
		// 客戶端來了一個新連結
		System.out.println("connection comes");
	}

	@Override
	public void channelInactive(ChannelHandlerContext ctx) throws Exception {
		// 客戶端走了一個
		System.out.println("connection leaves");
		ctx.close();
	}

	@Override
	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
		if (msg instanceof MessageInput) {
			System.out.println("read a message");
			// 用業務執行緒池處理訊息
			this.executor.execute(() -> {
				this.handleMessage(ctx, (MessageInput) msg);
			});
		}
	}

	private void handleMessage(ChannelHandlerContext ctx, MessageInput input) {
		// 業務邏輯在這裡
		Class<?> clazz = MessageRegistry.get(input.getType());
		if (clazz == null) {
			// 沒註冊的訊息用預設的處理器處理
			MessageHandlers.defaultHandler.handle(ctx, input.getRequestId(), input);
			return;
		}
		Object o = input.getPayload(clazz);
		// 這裡是小鮮的瑕疵,程式碼外觀上比較難看,但是大廚表示才藝不夠,很無奈
		// 讀者如果感興趣可以自己想辦法解決
		@SuppressWarnings("unchecked")
		IMessageHandler<Object> handler = (IMessageHandler<Object>) MessageHandlers.get(input.getType());
		if (handler != null) {
			handler.handle(ctx, input.getRequestId(), o);
		} else {
			// 用預設的處理器處理吧
			MessageHandlers.defaultHandler.handle(ctx, input.getRequestId(), input);
		}
	}

	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
		// 此處可能因為客戶端機器突發重啟
		// 也可能是客戶端連結閒置時間超時,後面的ReadTimeoutHandler丟擲來的異常
		// 也可能是訊息協議錯誤,序列化異常
		// etc.
		// 不管它,連結統統關閉,反正客戶端具備重連機制
		System.out.println("connection error");
		cause.printStackTrace();
		ctx.close();
	}

}

public class RPCServer {

	private String ip;
	private int port;
	private int ioThreads; // 用來處理網路流的讀寫執行緒
	private int workerThreads; // 用於業務處理的計算執行緒

	public RPCServer(String ip, int port, int ioThreads, int workerThreads) {
		this.ip = ip;
		this.port = port;
		this.ioThreads = ioThreads;
		this.workerThreads = workerThreads;
	}

	private ServerBootstrap bootstrap;
	private EventLoopGroup group;
	private MessageCollector collector;
	private Channel serverChannel;

    // 註冊服務的快捷方式
	public RPCServer service(String type, Class<?> reqClass, IMessageHandler<?> handler) {
		MessageRegistry.register(type, reqClass);
		MessageHandlers.register(type, handler);
		return this;
	}

    // 啟動RPC服務
	public void start() {
		bootstrap = new ServerBootstrap();
		group = new NioEventLoopGroup(ioThreads);
		bootstrap.group(group);
		collector = new MessageCollector(workerThreads);
		MessageEncoder encoder = new MessageEncoder();
		bootstrap.channel(NioServerSocketChannel.class).childHandler(new ChannelInitializer<SocketChannel>() {
			@Override
			public void initChannel(SocketChannel ch) throws Exception {
				ChannelPipeline pipe = ch.pipeline();
				// 如果客戶端60秒沒有任何請求,就關閉客戶端連結
				pipe.addLast(new ReadTimeoutHandler(60));
				// 掛上解碼器
				pipe.addLast(new MessageDecoder());
				// 掛上編碼器
				pipe.addLast(encoder);
				// 將業務處理器放在最後
				pipe.addLast(collector);
			}
		});
		bootstrap.option(ChannelOption.SO_BACKLOG, 100)  // 客戶端套件字接受佇列大小
		         .option(ChannelOption.SO_REUSEADDR, true) // reuse addr,避免埠衝突
		         .option(ChannelOption.TCP_NODELAY, true) // 關閉小流合併,保證訊息的及時性
		         .childOption(ChannelOption.SO_KEEPALIVE, true); // 長時間沒動靜的連結自動關閉
		serverChannel = bootstrap.bind(this.ip, this.port).channel();
		System.out.printf("server started @ %s:%d\n", ip, port);
	}

	public void stop() {
		// 先關閉服務端套件字
		serverChannel.close();
		// 再斬斷訊息來源,停止io執行緒池
		group.shutdownGracefully();
		// 最後停止業務執行緒
		collector.closeGracefully();
	}

}
複製程式碼

上面就是完整的伺服器菜譜,程式碼較多,讀者如果沒有Netty基礎的話,可能會看得眼花繚亂。如果你不常使用JDK的Executors框架,閱讀起來估計也夠嗆。如果讀者需要相關學習資料,可以找我索取。

客戶端菜譜

伺服器使用NIO實現,客戶端也可以使用NIO實現,不過必要性不大,用同步的socket實現也是沒有問題的。更重要的是,同步的程式碼比較簡短,便於理解。所以簡單起見,這裡使用了同步IO。

定義RPC請求物件和響應物件,和伺服器一一對應。

public class RPCRequest {

	private String requestId;
	private String type;
	private Object payload;

	public RPCRequest(String requestId, String type, Object payload) {
		this.requestId = requestId;
		this.type = type;
		this.payload = payload;
	}

	public String getRequestId() {
		return requestId;
	}

	public String getType() {
		return type;
	}

	public Object getPayload() {
		return payload;
	}

}

public class RPCResponse {

	private String requestId;
	private String type;
	private Object payload;

	public RPCResponse(String requestId, String type, Object payload) {
		this.requestId = requestId;
		this.type = type;
		this.payload = payload;
	}

	public String getRequestId() {
		return requestId;
	}

	public void setRequestId(String requestId) {
		this.requestId = requestId;
	}

	public String getType() {
		return type;
	}

	public void setType(String type) {
		this.type = type;
	}

	public Object getPayload() {
		return payload;
	}

	public void setPayload(Object payload) {
		this.payload = payload;
	}

}
複製程式碼

定義客戶端異常,用於統一丟擲RPC錯誤

public class RPCException extends RuntimeException {

	private static final long serialVersionUID = 1L;

	public RPCException(String message, Throwable cause) {
		super(message, cause);
	}

	public RPCException(String message) {
		super(message);
	}
	
	public RPCException(Throwable cause) {
		super(cause);
	}

}
複製程式碼

請求ID生成器,簡單的UUID64

public class RequestId {

	public static String next() {
		return UUID.randomUUID().toString();
	}

}
複製程式碼

響應型別註冊中心,和伺服器對應

public class ResponseRegistry {
	private static Map<String, Class<?>> clazzes = new HashMap<>();

	public static void register(String type, Class<?> clazz) {
		clazzes.put(type, clazz);
	}

	public static Class<?> get(String type) {
		return clazzes.get(type);
	}
}
複製程式碼

好,接下來進入客戶端的關鍵環節,連結管理、讀寫訊息、連結重連都在這裡

public class RPCClient {

	private String ip;
	private int port;
	private Socket sock;
	private DataInputStream input;
	private OutputStream output;

	public RPCClient(String ip, int port) {
		this.ip = ip;
		this.port = port;
	}

	public void connect() throws IOException {
		SocketAddress addr = new InetSocketAddress(ip, port);
		sock = new Socket();
		sock.connect(addr, 5000); // 5s超時
		input = new DataInputStream(sock.getInputStream());
		output = sock.getOutputStream();
	}

	public void close() {
		// 關閉連結
		try {
			sock.close();
			sock = null;
			input = null;
			output = null;
		} catch (IOException e) {
		}
	}

	public Object send(String type, Object payload) {
		// 普通rpc請求,正常獲取響應
		try {
			return this.sendInternal(type, payload, false);
		} catch (IOException e) {
			throw new RPCException(e);
		}
	}

	public RPCClient rpc(String type, Class<?> clazz) {
		// rpc響應型別註冊快捷入口
		ResponseRegistry.register(type, clazz);
		return this;
	}

	public void cast(String type, Object payload) {
		// 單向訊息,伺服器不得返回結果
		try {
			this.sendInternal(type, payload, true);
		} catch (IOException e) {
			throw new RPCException(e);
		}
	}

	private Object sendInternal(String type, Object payload, boolean cast) throws IOException {
		if (output == null) {
			connect();
		}
		String requestId = RequestId.next();
		ByteArrayOutputStream bytes = new ByteArrayOutputStream();
		DataOutputStream buf = new DataOutputStream(bytes);
		writeStr(buf, requestId);
		writeStr(buf, type);
		writeStr(buf, JSON.toJSONString(payload));
		buf.flush();
		byte[] fullLoad = bytes.toByteArray();
		try {
			// 傳送請求
			output.write(fullLoad);
		} catch (IOException e) {
			// 網路異常要重連
			close();
			connect();
			output.write(fullLoad);
		}
		if (!cast) {
			// RPC普通請求,要立即獲取響應
			String reqId = readStr();
			// 校驗請求ID是否匹配
			if (!requestId.equals(reqId)) {
				close();
				throw new RPCException("request id mismatch");
			}
			String typ = readStr();
			Class<?> clazz = ResponseRegistry.get(typ);
			// 響應型別必須提前註冊
			if (clazz == null) {
				throw new RPCException("unrecognized rpc response type=" + typ);
			}
			// 反序列化json串
			String payld = readStr();
			Object res = JSON.parseObject(payld, clazz);
			return res;
		}
		return null;
	}

	private String readStr() throws IOException {
		int len = input.readInt();
		byte[] bytes = new byte[len];
		input.readFully(bytes);
		return new String(bytes, Charsets.UTF8);
	}

	private void writeStr(DataOutputStream out, String s) throws IOException {
		out.writeInt(s.length());
		out.write(s.getBytes(Charsets.UTF8));
	}
}
複製程式碼

牢騷重提

本以為是小菜一碟,但是編寫完整的程式碼和文章卻將近花費了一天的時間,深感寫碼要比做菜耗時太多了。因為只是為了教學目的,所以在實現細節上還有好多沒有仔細去雕琢的地方。如果是要做一個開源專案,力求非常完美的話。至少還要考慮一下幾點。

  1. 客戶端連線池
  2. 多服務程式負載均衡
  3. 日誌輸出
  4. 引數校驗,異常處理
  5. 客戶端流量攻擊
  6. 伺服器壓力極限

如果要參考grpc的話,還得實現流式響應處理。如果還要為了節省網路流量的話,又需要在協議上下功夫。這一大堆的問題還是拋給讀者自己思考去吧。

關注公眾號「碼洞」,傳送「RPC」即可獲取以上完整菜譜的GitHub開原始碼連結。讀者有什麼不明白的地方,洞主也會一一解答。

相關文章