欢迎访问 生活随笔!

生活随笔

当前位置: 首页 > 编程资源 > 编程问答 >内容正文

编程问答

手撕 RPC 2

发布时间:2024/9/30 编程问答 42 豆豆
生活随笔 收集整理的这篇文章主要介绍了 手撕 RPC 2 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

把单个客户端改成多个呢?

在成功连接几个client后报错了,是拆解包的时候出错,为什么呢?看一下多个客户端连接时候的 netty 的模型:


因为服务程序什么时候从内核的buffer里读走数据跟客户端往buffer里写数据是两个独立的动作,所以在多客户端(多个socket)的场景下是不能整齐的读到正确的包,所以报错解码错误,IO是双向的,服务端解码错误的问题在客户端也一样,怎么处理?

处理通信流的解码问题

出现这种问题底层的原因是

因为多个socket往一个buffer里写,所以要保证程序读取的时候每次都能跟包的头部(header和body的头部)对齐,而且每次readChannel()的之前程序把该次请求的包发完整。强大的 netty 给了我们 ByteToMessageDecoder ,在 pipeline 的业务事件之前加上解码事件就可以了。

另外解决一些问题

1.服务器使用20个线程处理 listen socket 连接和 IO socket。
2.多个 client 连接,维护一个线程池来管理这些连接。
3.使用 netty 的 ByteToMessageDecoder 解决解码不正确问题。
4.使用 CompletableFuture 来获取客户端调用方法的返回。

package rpc;import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.*; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.ByteToMessageDecoder; import org.junit.Test;import java.io.*; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.net.InetSocketAddress; import java.util.List; import java.util.Random; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger;/*** 1.假设一个需求,写一个 rpc* 2.来回通信,连接数量,拆包* 3.动态代理,序列化,协议封装* 4.连接池*/ public class MyRpcTest {/*** 模拟 consumer*/@Testpublic void startServer() { // 启动20个 selector,每个线程都维护列一个 IO(注意不是 listen) 的 socket 的序列,如果有socket 进来了,就处理,这是 netty 内部自己实现的NioEventLoopGroup eventExecutors = new NioEventLoopGroup(20);NioEventLoopGroup worker = eventExecutors;ServerBootstrap serverBootstrap = new ServerBootstrap();ChannelFuture localhost = serverBootstrap.group(eventExecutors, worker).channel(NioServerSocketChannel.class).childHandler(new ChannelInitializer<NioSocketChannel>() {@Overrideprotected void initChannel(NioSocketChannel nioSocketChannel) {ChannelPipeline pipeline = nioSocketChannel.pipeline();pipeline.addLast(new MyDecoder());pipeline.addLast(new ServerRequestHandler());}}).bind(new InetSocketAddress("localhost", 9090));try {localhost.sync().channel().closeFuture().sync();} catch (InterruptedException e) {e.printStackTrace();}}@Testpublic void get() {// 开启一个线程模拟服务器new Thread(() -> {startServer();}).start();try {Thread.sleep(2000);} catch (InterruptedException e) {e.printStackTrace();}System.out.println("server start ....");// 获取服务AtomicInteger atomicInteger = new AtomicInteger();int size = 30;Thread[] threads = new Thread[size];for (int i = 0; i < size; i++) {threads[i] = new Thread(() -> {Car car = proxyGet(Car.class);int args = atomicInteger.incrementAndGet();String s = car.race("client: " + args);// 打印传参看是不是服务能对应上System.out.println("client got args: " + s + " --- " + args);});}for (Thread thread : threads) {thread.start();}try {System.in.read();} catch (IOException e) {e.printStackTrace();}}private static <T> T proxyGet(Class<T> interfaceInfo) {// 实现各个版本的动态代理ClassLoader classLoader = interfaceInfo.getClassLoader();Class<?>[] interfaces = {interfaceInfo};// 用 jdk 的动态代理实现return (T) Proxy.newProxyInstance(classLoader, interfaces, new InvocationHandler() {@Overridepublic Object invoke(Object proxy, Method method, Object[] args) throws Throwable {// 1.调用,服务、方法、参数,封装成 contentString name = interfaceInfo.getName(); // 服务名String methodName = method.getName(); // 方法名Class<?>[] parameterTypes = method.getParameterTypes(); // 方法的返回类型// 2.把调用服务的信息封装成一个可以序列化的对象// 先封装 bodyMyContent content = new MyContent();content.setName(name);content.setMethodName(methodName);content.setParameterTypes(parameterTypes);content.setArgs(args);// 把 content 做成字节数组准备写出去ByteArrayOutputStream bos = new ByteArrayOutputStream();ObjectOutputStream oos = new ObjectOutputStream(bos);oos.writeObject(content);byte[] msgBody = bos.toByteArray();// 再封装 header,header 需要 body 信息,MyHeader myHeader = createHeader(msgBody);bos.reset();oos = new ObjectOutputStream(bos);oos.writeObject(myHeader);byte[] msgHeader = bos.toByteArray();// 3.服务准备好了,接下来准备连接,模拟一个 size 是1的连接池ClientFactory factory = ClientFactory.getFactory();NioSocketChannel client = factory.getClient(new InetSocketAddress("localhost", 9090));ByteBuf byteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(msgHeader.length + msgBody.length);// 用这个控制结果返回,其中 complete() 方法异步计算结果,get()方法阻塞获取结果CompletableFuture<String> future = new CompletableFuture<>();ResponseMappingHandler.addCallBack(myHeader.getRequestId(), future);byteBuf.writeBytes(msgHeader);byteBuf.writeBytes(msgBody);ChannelFuture channelFuture = client.writeAndFlush(byteBuf);channelFuture.sync();// 把结果返回给客户端啊return future.get();}});}static MyHeader createHeader(byte[] msgBytes) {MyHeader header = new MyHeader();int size = msgBytes.length;// 用16进制的,32 位可以做很多事情int f = 0x14141414;long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());header.setFlag(f);header.setRequestId(requestId);header.setDataLen(size);return header;} }/*** 模拟服务*/ interface Car {String race(String msg); }/*** 头部定义三个标志* 1.方法的标记,用32位的位标记* 2.请求的 id* 3.请求体的长度*/ class MyHeader implements Serializable {int flag;long requestId;long dataLen;public int getFlag() {return flag;}public void setFlag(int flag) {this.flag = flag;}public long getRequestId() {return requestId;}public void setRequestId(long requestId) {this.requestId = requestId;}public long getDataLen() {return dataLen;}public void setDataLen(long dataLen) {this.dataLen = dataLen;} }/*** 模拟请求体*/ class MyContent implements Serializable {// 服务名String name;// 方法名String methodName;// 返回值类型Class<?>[] parameterTypes;// 参数Object[] args;String res;public String getRes() {return res;}public void setRes(String res) {this.res = res;}public String getName() {return name;}public void setName(String name) {this.name = name;}public String getMethodName() {return methodName;}public void setMethodName(String methodName) {this.methodName = methodName;}public Class<?>[] getParameterTypes() {return parameterTypes;}public void setParameterTypes(Class<?>[] parameterTypes) {this.parameterTypes = parameterTypes;}public Object[] getArgs() {return args;}public void setArgs(Object[] args) {this.args = args;} }/*** 模拟客户端的创建,用单例*/ class ClientFactory {int pollSize = 50; // 连接池里有50个连接NioEventLoopGroup clientWorker;Random random = new Random();private static final ClientFactory factory;private ClientFactory() {}static {factory = new ClientFactory();}public static ClientFactory getFactory() {return factory;}ConcurrentHashMap<InetSocketAddress, ClientPool> outboxes = new ConcurrentHashMap<InetSocketAddress, ClientPool>();public synchronized NioSocketChannel getClient(InetSocketAddress address) {ClientPool clientPool = outboxes.get(address);if (clientPool == null) {outboxes.putIfAbsent(address, new ClientPool(pollSize));clientPool = outboxes.get(address);}int i = random.nextInt(pollSize);// 如果有就返回if (clientPool.clients[i] != null && clientPool.clients[i].isActive()) {return clientPool.clients[i];}// 没有就创建synchronized (clientPool.locks[i]) {return clientPool.clients[i] = create(address);}}private NioSocketChannel create(InetSocketAddress address) {// 基于 netty 的客户端创建方式clientWorker = new NioEventLoopGroup(1);Bootstrap bs = new Bootstrap();ChannelFuture connect = bs.group(clientWorker).channel(NioSocketChannel.class).handler(new ChannelInitializer<NioSocketChannel>() {@Overrideprotected void initChannel(NioSocketChannel nioSocketChannel) {ChannelPipeline pipeline = nioSocketChannel.pipeline();pipeline.addLast(new MyDecoder());pipeline.addLast(new ClientResponses());}}).connect(address);try {NioSocketChannel client = (NioSocketChannel) connect.sync().channel();return client;} catch (InterruptedException e) {e.printStackTrace();}return null;} }/*** 模拟线连接池*/ class ClientPool {NioSocketChannel[] clients;Object[] locks;ClientPool(int size) {clients = new NioSocketChannel[size];locks = new Object[size];for (int i = 0; i < size; i++) {在这里插入代码片locks[i] = new Object();}} }class ClientResponses extends ChannelInboundHandlerAdapter {@Overridepublic void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {MyDataPackage data = (MyDataPackage) msg;ResponseMappingHandler.runCallBack(data);} }/*** 用于主线程的阻塞的控制*/ class ResponseMappingHandler {static ConcurrentHashMap<Long, CompletableFuture> mapping = new ConcurrentHashMap<>();public static void addCallBack(long requestId, CompletableFuture cb) {mapping.putIfAbsent(requestId, cb);}public static void runCallBack(MyDataPackage data) {mapping.get(data.getHeader().getRequestId()).complete(data.getContent().getRes());removeCallBack(data.getHeader().getRequestId());}public static void removeCallBack(long requestId) {mapping.remove(requestId);} }/*** 服务端注册的事件* 没有具体的业务逻辑,*/ class ServerRequestHandler extends ChannelInboundHandlerAdapter {@Overridepublic void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {MyDataPackage data = (MyDataPackage) msg;String ioThreadName = Thread.currentThread().getName();ctx.executor().parent().next().execute(() -> { // ctx.executor().execute(() -> {String execThreadName = Thread.currentThread().getName();MyContent content = new MyContent();String s = "io thread: " + ioThreadName + " exec thread: " + execThreadName + " from args: " + data.getContent().getArgs()[0];content.setRes(s);byte[] contentByte = MySerializeUtil.serialize(content);MyHeader header = new MyHeader();// 0x14141424 标记是服务端, 客户端是 0x14141414header.setFlag(0x14141424);header.setRequestId(data.getHeader().getRequestId());header.setDataLen(contentByte.length);byte[] headerByte = MySerializeUtil.serialize(header);ByteBuf byteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(headerByte.length + contentByte.length);byteBuf.writeBytes(headerByte);byteBuf.writeBytes(contentByte);ctx.writeAndFlush(byteBuf);});} }/*** 利用 netty 的 ByteToMessageDecoder 进行解码*/ class MyDecoder extends ByteToMessageDecoder {@Overrideprotected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception {// 82 是打断点跟踪得出来的 header 的长度while (byteBuf.readableBytes() >= 82) {byte[] bytes = new byte[82];// 不让指针移动byteBuf.getBytes(byteBuf.readerIndex(), bytes);ByteArrayInputStream bis = new ByteArrayInputStream(bytes);ObjectInputStream objectInputStream = new ObjectInputStream(bis);MyHeader myHeader = (MyHeader) objectInputStream.readObject();if (byteBuf.readableBytes() >= myHeader.getDataLen()) {byteBuf.readBytes(82); // 把指针移动到 content 的头部byte[] data = new byte[(int) myHeader.getDataLen()];byteBuf.readBytes(data);ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(data);ObjectInputStream ois = new ObjectInputStream(byteArrayInputStream);// 因为客户端和服务端都用同一个 decoder,区分一下if (myHeader.getFlag() == 0x14141414) {MyContent myContent = (MyContent) ois.readObject();list.add(new MyDataPackage(myHeader, myContent));} else if (myHeader.getFlag() == 0x14141424) {MyContent myContent = (MyContent) ois.readObject();list.add(new MyDataPackage(myHeader, myContent));}} else {break;}}} }class MyDataPackage {private MyHeader header;private MyContent content;public MyDataPackage(MyHeader myHeader, MyContent myContent) {this.header = myHeader;this.content = myContent;}public MyHeader getHeader() {return header;}public void setHeader(MyHeader header) {this.header = header;}public MyContent getContent() {return content;}public void setContent(MyContent content) {this.content = content;} }

总结

以上是生活随笔为你收集整理的手撕 RPC 2的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。