huangshihang

A Simple Implementation of the Hadoop RPC Client


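In Hadoop's RPC code, the Client maintains the connections between the client and the server; each connection sends the client's requests to the server and receives the server's responses.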

 

The objects inside the Client are related as follows:

 

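1) A Client object maintains multiple connections to servers;

2) connections is the collection of Connection objects. Each Connection is identified by a ConnectionId, which holds the address (host and port) of the server socket;

3) A Connection maintains the link to the server and sends and receives data. It stores a Call for every request made to the server: the Call is added to the Connection when the request is issued and removed from the Connection once the response has arrived.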

(The Client code in Hadoop's RPC handles many details carefully; the code in this article is a trimmed version with those details removed.)
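
The ownership described above (one client, many connections, many pending calls per connection) can be sketched with plain JDK collections. This is only an illustrative model: `ClientModel`, `ConnectionModel`, and the `String` request payload are hypothetical names invented here, not Hadoop classes, and the server address stands in for ConnectionId.

```java
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical model of the Client -> Connection -> Call ownership (not Hadoop code).
class ConnectionModel {
    final InetSocketAddress server;   // plays the role of ConnectionId
    final Map<Integer, String> pendingCalls = new ConcurrentHashMap<>();  // call id -> request

    ConnectionModel(InetSocketAddress server) {
        this.server = server;
    }

    // A call is registered when the request is issued...
    void addCall(int id, String request) {
        pendingCalls.put(id, request);
    }

    // ...and dropped once its response has been received.
    void completeCall(int id) {
        pendingCalls.remove(id);
    }
}

class ClientModel {
    // One client holds many connections, one per server address.
    final Map<InetSocketAddress, ConnectionModel> connections = new ConcurrentHashMap<>();

    // Reuses the connection for an address if present, otherwise creates it.
    ConnectionModel connectionFor(InetSocketAddress server) {
        return connections.computeIfAbsent(server, ConnectionModel::new);
    }

    public static void main(String[] args) {
        ClientModel client = new ClientModel();
        InetSocketAddress addr = InetSocketAddress.createUnresolved("127.0.0.1", 8088);
        ConnectionModel connection = client.connectionFor(addr);
        connection.addCall(0, "ping");
        System.out.println("pending after send: " + connection.pendingCalls.size());
        connection.completeCall(0);
        System.out.println("pending after response: " + connection.pendingCalls.size());
    }
}
```

The sketch uses `ConcurrentHashMap.computeIfAbsent` for the get-or-create step; the article's own code does the same job with a `Hashtable` and a `synchronized` block.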

(1) The Call code:

 

 

static class Call{
        final int id;                         // identifies the Call
        final Writable rpcRequest;      // the request
        Writable rpcResponse;          // the response
        boolean done;                     // set to true once the response has been received

        public Call(Writable param){
            final Integer id = callId.get();
            if(id == null){
                this.id = nextCallId();
            }else{
                callId.set(null);
                this.id = id;
            }
            this.rpcRequest = param;
        }

        public synchronized void callCompleted(){        // after the response arrives, set the flag to true and wake the waiting thread
            done = true;
            notify();
        }

        public synchronized void setRpcResponse(Writable rpcResponse){
            this.rpcResponse = rpcResponse;
            callCompleted();
        }
    }
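
The `done` flag and `notify()` in Call form a classic wait/notify handshake: the requesting thread blocks on the call object until the receiver thread delivers the response. A minimal JDK-only sketch of that handshake follows. `PendingCall` and its `String` response are hypothetical stand-ins for Call and Writable, and the sketch uses `notifyAll()` where the listing uses `notify()`, which is a choice made here rather than something taken from Hadoop.

```java
// Hypothetical stand-in for Call: one thread waits, another delivers the result.
class PendingCall {
    private String response;   // plays the role of rpcResponse
    private boolean done;      // guarded by this object's monitor

    // Called by the receiver thread when the response arrives.
    synchronized void complete(String value) {
        response = value;
        done = true;
        notifyAll();           // wake every thread blocked in await()
    }

    // Called by the requesting thread; the loop guards against spurious wakeups.
    synchronized String await() {
        boolean interrupted = false;
        while (!done) {
            try {
                wait();
            } catch (InterruptedException e) {
                interrupted = true;   // remember the interrupt and keep waiting
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();   // restore the interrupt status
        }
        return response;
    }

    // Starts a thread that completes the call after a short delay, then waits for it.
    static String demo(String value) {
        PendingCall call = new PendingCall();
        Thread receiver = new Thread(() -> {
            try {
                Thread.sleep(50);
            } catch (InterruptedException ignored) {
                // fall through and complete the call anyway
            }
            call.complete(value);
        });
        receiver.start();
        return call.await();
    }

    public static void main(String[] args) {
        System.out.println(demo("pong"));
    }
}
```

Checking `done` inside a `while` loop, as both the sketch and `Client.call` do, matters because `wait()` may return without a matching notification.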

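 (2) Connection extends Thread. It is started after initialization and then keeps checking for responses; when a response arrives, it looks up the matching Call.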

private class Connection extends Thread{
        private InetSocketAddress server;            // server socket address
        private final ConnectionId remoteId;
        private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>();    //Call collections
        private Socket socket = null;                    
        private DataInputStream in;
        private DataOutputStream out;
        private final Object sendRpcRequestLock = new Object();               // lock that serializes request sending

        public Connection(ConnectionId remoteId){
            this.remoteId = remoteId;
            server = remoteId.getAddress();
        }

        private synchronized boolean addCall(Call call){                // register the pending call in the Hashtable
            calls.put(call.id, call);
            notify();
            return true;
        }

        private synchronized void setUpConnection(){
            try {
                this.socket = socketFactory.createSocket();
                socket.connect(server);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        private synchronized void setUpIOStreams(){                          // connect the socket and open the input and output streams
            if(socket != null) {
                return;
            }
            System.out.println("connect the socket and create input and output stream");
            setUpConnection();
            try {
                InputStream inputStream = socket.getInputStream();
                OutputStream outputStream = socket.getOutputStream();
                this.in = new DataInputStream(new BufferedInputStream(inputStream));
                this.out = new DataOutputStream(new BufferedOutputStream(outputStream));
            } catch (IOException e) {
                e.printStackTrace();
            }
            start();
        }

        private void closeConnection(){
            if(socket == null){
                return;
            }
            try {
                socket.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
            socket = null;
        }

        public void sendRPCRequest(Call call) throws IOException {         // send the request: the call id followed by the request body; the id is used to match the response
            final DataOutputBuffer d = new DataOutputBuffer();
            System.out.println("prepare to write the data of the call.........");
            d.writeInt(call.id);
            call.rpcRequest.write(d);
            synchronized(sendRpcRequestLock){
                Future<?> senderFuture = SEND_PARAMS_EXECUTOR.submit(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            synchronized (Connection.this.out) {
                                byte[] data = d.getData();
                                int totalLength = d.getLength();
                                out.write(data, 0, totalLength);
                                out.flush();
                            }
                        }
                        catch (IOException e) {
                            e.printStackTrace();
                        }finally {
                            try {
                                d.close();
                            } catch (IOException e) {
                                e.printStackTrace();
                            }
                        }
                    }
                });

                try {
                    senderFuture.get();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                } catch (ExecutionException e) {
                    e.printStackTrace();
                }

            }
        }

        private void receiveRpcResponse() throws NoSuchMethodException, InvocationTargetException {                          // receive one response
            try {
                try {                                                // sleep() is only for testing and can be removed
                    sleep(500);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                Integer callId = in.readInt();
                Call call = calls.get(callId);
                // The response class is a custom class that implements Hadoop io's Writable.
                // An RPC consists of a method part and a parameter part, so the constructor takes both.
                Writable value = valueClass.getConstructor(Method.class, Object[].class).newInstance(
                        Client.class.getMethod("call", Writable.class, Client.ConnectionId.class), new Object[]{});
                value.readFields(in);
                calls.remove(callId);
                call.setRpcResponse(value);                         // store the response, set done, and wake the waiting thread
                System.out.println("remove the call and the calls:" + calls.size() + ",receive the response:" + value);
            } catch (IOException e) {
                e.printStackTrace();
            } catch (InstantiationException e) {
                e.printStackTrace();
            } catch (IllegalAccessException e) {
                e.printStackTrace();
            }
        }

        private synchronized boolean waitForWork(){
            if(calls.isEmpty()){
                try {
                    wait(100);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            if(!calls.isEmpty()){
                return true;
            }else {
                return false;
            }
        }

        private synchronized void close(){
            connections.remove(remoteId);
            try {
                in.close();
                out.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
            closeConnection();
            cleanUpCalls();
        }

        private void cleanUpCalls(){
            Iterator<Entry<Integer, Call>> itor = calls.entrySet().iterator();
            while (itor.hasNext()){
                itor.next();                     // advance first: remove() without a preceding next() throws IllegalStateException
                itor.remove();
            }
        }

        public void run(){
            while(waitForWork()){
                System.out.println("prepare to accept the response...............");
                try {
                    receiveRpcResponse();
                } catch (NoSuchMethodException e) {
                    e.printStackTrace();
                } catch (InvocationTargetException e) {
                    e.printStackTrace();
                }
            }
            close();
        }

    }
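
The receive path in Connection relies on each response carrying the id of the call it answers, so responses can be matched to pending calls even if they arrive out of order. The following JDK-only sketch shows that demultiplexing step in isolation. `ResponseDemux`, the in-memory byte array used in place of a socket, and the `[int id][UTF string]` frame layout are assumptions made for illustration; the listing above reads a Writable after the id instead of a string.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;

// Hypothetical demo of matching responses to pending calls by id (not Hadoop code).
class ResponseDemux {
    // Encodes responses as [int callId][UTF payload] frames, in the given order.
    static byte[] encode(int[] ids, String[] payloads) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            for (int i = 0; i < ids.length; i++) {
                out.writeInt(ids[i]);
                out.writeUTF(payloads[i]);
            }
            out.flush();
            return bytes.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Reads frames until every pending id has been answered; returns callId -> payload.
    static Map<Integer, String> receive(byte[] wire, Map<Integer, String> pending) {
        Map<Integer, String> results = new HashMap<>();
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(wire))) {
            while (!pending.isEmpty()) {
                int callId = in.readInt();
                String payload = in.readUTF();          // always consume the frame body
                if (pending.remove(callId) != null) {   // skip ids nobody is waiting for
                    results.put(callId, payload);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return results;
    }

    public static void main(String[] args) {
        Map<Integer, String> pending = new HashMap<>();
        pending.put(1, "request-1");
        pending.put(2, "request-2");
        // Responses arrive in the opposite order from the requests.
        byte[] wire = encode(new int[] {2, 1}, new String[] {"second", "first"});
        System.out.println(receive(wire, pending));
    }
}
```

The sketch skips frames whose id has no pending call. The listing above does not make that check, so there an unknown id would produce a `NullPointerException` at `call.setRpcResponse(value)`.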

 (3) The ConnectionId class:

public static class ConnectionId{
        InetSocketAddress address;

        ConnectionId(InetSocketAddress address){
            this.address = address;
        }

        InetSocketAddress getAddress(){ return this.address; }

    }
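
Because `connections` is a Hashtable keyed by ConnectionId, an existing Connection is only found if two ids for the same address compare equal. A class that does not override `equals` and `hashCode` falls back to object identity, so every freshly constructed id would map to a new Connection. The sketch below shows one way to give such a key value-based equality. `AddressKey` is a hypothetical name used here so it is not mistaken for the article's class, and this is only one possible design.

```java
import java.net.InetSocketAddress;
import java.util.Hashtable;
import java.util.Objects;

// Hypothetical variant of ConnectionId with value-based equality (not Hadoop code).
final class AddressKey {
    private final InetSocketAddress address;

    AddressKey(InetSocketAddress address) {
        this.address = Objects.requireNonNull(address);
    }

    InetSocketAddress getAddress() {
        return address;
    }

    @Override
    public boolean equals(Object other) {
        return other instanceof AddressKey
                && address.equals(((AddressKey) other).address);
    }

    @Override
    public int hashCode() {
        return address.hashCode();   // must agree with equals for hash-based lookups
    }

    // Returns true when a second key built from the same host and port finds the stored value.
    static boolean lookupWithFreshKey(String host, int port) {
        Hashtable<AddressKey, String> connections = new Hashtable<>();
        connections.put(new AddressKey(InetSocketAddress.createUnresolved(host, port)), "connection");
        return connections.containsKey(
                new AddressKey(InetSocketAddress.createUnresolved(host, port)));
    }

    public static void main(String[] args) {
        System.out.println(lookupWithFreshKey("127.0.0.1", 8088));
    }
}
```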

 (4) Members of the Client class:

 

private static final AtomicInteger callIdCounter = new AtomicInteger();
    private static final ThreadLocal<Integer> callId = new ThreadLocal<Integer>();
    private Hashtable<ConnectionId, Connection> connections = new Hashtable<ConnectionId, Connection>();
    private SocketFactory socketFactory;
    private static final ExecutorService SEND_PARAMS_EXECUTOR = Executors.newCachedThreadPool(
            new ThreadFactoryBuilder().setDaemon(true).setNameFormat("IPC Params sending Thread #%d").build()
    );

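    private Class<? extends Writable> valueClass;

    // Constructor used by the test in (6). It is not shown in the trimmed listing, so this is an
    // assumed minimal version: it stores the response class and uses the JVM's default socket factory.
    public Client(Class<? extends Writable> valueClass){
        this.valueClass = valueClass;
        this.socketFactory = SocketFactory.getDefault();
    }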

 (5) Methods of the Client class:

 

public Writable call(Writable rpcRequest, ConnectionId remoteId){      // called to send a request to the server; blocks until the response arrives
        final Call call = new Call(rpcRequest);
        Connection connection = getConnection(remoteId, call);
        try {
            connection.sendRPCRequest(call);
        } catch (IOException e) {
            e.printStackTrace();
        }
        synchronized (call){
            while (!call.done){
                try {
                    System.out.println("waiting for the complete..........");
                    call.wait();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }

        return call.rpcResponse;
    }

    private Connection getConnection(ConnectionId remoteId, Call call){
        Connection connection;
        do {
            synchronized (connections) {
                connection = connections.get(remoteId);
                if (connection == null) {
                    connection = new Connection(remoteId);
                    connections.put(remoteId, connection);
                }
            }
        }while(!connection.addCall(call));
        System.out.println("create or already have connection in connections:" + connection.getName() + "--" + call.id + ",calls size:" + connection.calls.size());
        connection.setUpIOStreams();
        return connection;
    }

    public static int nextCallId(){
        return callIdCounter.getAndIncrement() & 0x7FFFFFFF;
    }
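
The mask in `nextCallId()` deserves a short illustration. `AtomicInteger.getAndIncrement()` wraps around to `Integer.MIN_VALUE` after `Integer.MAX_VALUE`, and `& 0x7FFFFFFF` clears the sign bit so the id stays non-negative. The sketch below reuses the same expression; `CallIdDemo` and its configurable start value are hypothetical and exist only to make the overflow easy to reach.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo of the id-generation expression used by nextCallId().
class CallIdDemo {
    private final AtomicInteger counter;

    CallIdDemo(int start) {
        counter = new AtomicInteger(start);
    }

    // Clearing the sign bit keeps ids in the range [0, Integer.MAX_VALUE].
    int next() {
        return counter.getAndIncrement() & 0x7FFFFFFF;
    }

    public static void main(String[] args) {
        CallIdDemo ids = new CallIdDemo(Integer.MAX_VALUE);
        System.out.println(ids.next());   // counter was at Integer.MAX_VALUE
        System.out.println(ids.next());   // counter had overflowed to Integer.MIN_VALUE
    }
}
```

Run as written, this prints 2147483647 followed by 0: once the counter overflows, the ids start again from zero instead of going negative.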

 

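(6) Test: a thread simulates the server, receiving the request and returning a result. Both the request and the response are Invocation objects.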

public static void main(String[] args){
        Thread thread = new Thread(new Runnable() {
            @Override
            public void run() {
                ServerSocket serverSocket;
                boolean flag = true;
                try {
                    serverSocket = new ServerSocket(8088);
                    while (flag) {
                        try {
                            Thread.sleep(1000);
                        } catch (InterruptedException e) {
                            e.printStackTrace();
                        }
                        Socket socket = serverSocket.accept();
                        System.out.println("accept socket at port:8088.............");
                        DataOutputStream out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()));
                        DataInputStream in = new DataInputStream(new BufferedInputStream(socket.getInputStream()));
                        try {
                            Invocation invocation = new Invocation(Client.class.getMethod("call", Writable.class, Client.ConnectionId.class), new Object[]{});
                            System.out.println("prepare to read information from in of socket:<<<<<<<<<<<");
                            int id = in.readInt();
                            System.out.println("read int :" + id);
                            invocation.readFields(in);
                            System.out.println("read invocation :" + invocation);
                            out.writeInt(id);
                            invocation.write(out);
                            out.flush();
                            System.out.println("write procession is over>>>>>>>>>>>>>" + invocation);
                        } catch (NoSuchMethodException e) {
                            e.printStackTrace();
                        }
                        out.close();
                        in.close();
                        flag = false;
                    }
                } catch (IOException e) {
                    e.printStackTrace();
                }


            }
        });

        thread.start();


        Client.ConnectionId id = new Client.ConnectionId(new InetSocketAddress("127.0.0.1", 8088));

        Client client = new Client(Invocation.class);
        try {
            try {
                Thread.sleep(500);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            Invocation invocation = (Invocation)client.call(new Invocation(Client.class.getMethod("call", Writable.class, Client.ConnectionId.class), new Object[]{}), id);
        } catch (NoSuchMethodException e) {
            e.printStackTrace();
        }
    }
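
The same request/echo round trip can be tried without any Hadoop classes. The sketch below is a JDK-only approximation of the test above, under several assumptions: `EchoRoundTrip` is a hypothetical name, the payload is a UTF string rather than an Invocation, and the server binds to an ephemeral loopback port instead of 8088. Since a `ServerSocket` is already listening once it is constructed, the sketch needs no `sleep` calls to order the two sides.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;

// Hypothetical loopback test: send [int id][UTF text] and read the echo back.
class EchoRoundTrip {
    // Returns "id:text" as echoed by a one-shot server on the loopback interface.
    static String roundTrip(int id, String text) {
        try (ServerSocket server = new ServerSocket(0, 1, InetAddress.getLoopbackAddress())) {
            Thread echo = new Thread(() -> {
                try (Socket s = server.accept();
                     DataInputStream in = new DataInputStream(s.getInputStream());
                     DataOutputStream out = new DataOutputStream(s.getOutputStream())) {
                    out.writeInt(in.readInt());   // echo the call id
                    out.writeUTF(in.readUTF());   // echo the payload
                    out.flush();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            });
            echo.setDaemon(true);
            echo.start();

            // Client side: connect to whichever port the server was given.
            try (Socket s = new Socket(server.getInetAddress(), server.getLocalPort());
                 DataOutputStream out = new DataOutputStream(s.getOutputStream());
                 DataInputStream in = new DataInputStream(s.getInputStream())) {
                out.writeInt(id);
                out.writeUTF(text);
                out.flush();
                int echoedId = in.readInt();
                String echoedText = in.readUTF();
                return echoedId + ":" + echoedText;
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(roundTrip(42, "hello"));
    }
}
```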

 (7) The custom Invocation class:

public class Invocation implements Writable {
    private String methodName;
    private Class<?>[] parameterClasses;
    private Object[] parameters;
    private Configuration configure;

    public Object[] getParameters() {
        return parameters;
    }

    public Class<?>[] getParamterClasses() {
        return parameterClasses;
    }

    public String getMethodName() {
        return methodName;
    }
    public Configuration getConfigure() {
        return configure;
    }

    public void setConfigure(Configuration configure) {
        this.configure = configure;
    }

    public Invocation(){

    }

    public Invocation(Method method, Object[] parameters){
        this.methodName = method.getName();
        this.parameterClasses = method.getParameterTypes();
        this.parameters = parameters;
    }

    @Override
    public void write(DataOutput dataOutput) throws IOException {
        UTF8.writeString(dataOutput, methodName);
        dataOutput.writeInt(parameters.length);          // write the parameter count so readFields knows how many to read
        for(int i = 0; i < parameters.length; i++) {
            ObjectWritable.writeObject(dataOutput, parameters[i], parameterClasses[i], this.configure, true);
        }
    }

    @Override
    public void readFields(DataInput dataInput) throws IOException {
        methodName = UTF8.readString(dataInput);
        parameters = new Object[dataInput.readInt()];    // size from the wire; parameters is null after the no-arg constructor
        parameterClasses = new Class[parameters.length];
        ObjectWritable objectWritable = new ObjectWritable();
        for (int i = 0; i < parameters.length; i++) {
            parameters[i] =
                    ObjectWritable.readObject(dataInput, objectWritable, this.configure);
            parameterClasses[i] = objectWritable.getDeclaredClass();
        }
    }
}
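
A reader can only deserialize a variable-length parameter list if it knows how many elements follow, which is commonly handled by writing a count before the elements. The JDK-only sketch below shows that count-prefixed layout with `DataOutputStream` and `DataInputStream`. `SimpleInvocation` is a hypothetical, String-only analogue of Invocation: it does not use Writable or ObjectWritable, and its wire format is an assumption for illustration rather than Hadoop's.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;

// Hypothetical String-only analogue of Invocation with a count-prefixed parameter list.
class SimpleInvocation {
    String methodName;
    String[] parameters;

    SimpleInvocation() {
        // empty instance, to be filled in by fromBytes
    }

    SimpleInvocation(String methodName, String... parameters) {
        this.methodName = methodName;
        this.parameters = parameters;
    }

    // Layout: [UTF methodName][int count][UTF parameter] * count
    byte[] toBytes() {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            out.writeUTF(methodName);
            out.writeInt(parameters.length);   // count prefix
            for (String parameter : parameters) {
                out.writeUTF(parameter);
            }
            out.flush();
            return bytes.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Reads the count first, then exactly that many parameters.
    static SimpleInvocation fromBytes(byte[] wire) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(wire))) {
            SimpleInvocation invocation = new SimpleInvocation();
            invocation.methodName = in.readUTF();
            invocation.parameters = new String[in.readInt()];
            for (int i = 0; i < invocation.parameters.length; i++) {
                invocation.parameters[i] = in.readUTF();
            }
            return invocation;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    @Override
    public String toString() {
        return methodName + Arrays.toString(parameters);
    }

    public static void main(String[] args) {
        byte[] wire = new SimpleInvocation("call", "a", "b").toBytes();
        System.out.println(fromBytes(wire));
    }
}
```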

 

 
