RMI Basics

Posted by 毛利_小五郎 on 2024-08-05

Analyzing how RMI communication works by stepping through IDEA breakpoints

1. Overview of the flow

RMI has three parts:

·RMI Registry

·RMI Server

·RMI Client

The flowchart is provided in the accompanying file.

2. Creating the remote service

RMIServer

public class RMIServer {
    public static void main(String[] args) throws Exception {
        IRemoteObj remoteObj = new RemoteObjImpl();
//        Registry r = LocateRegistry.createRegistry(1099);
//        r.bind("remoteObj", remoteObj);
    }
}

RemoteObjImpl

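import java.rmi.RemoteException;
import java.rmi.server.UnicastRemoteObject;

public class RemoteObjImpl extends UnicastRemoteObject implements IRemoteObj {
    public RemoteObjImpl() throws RemoteException {
        // implicitly calls super(), which exports this object on port 0 (an anonymous port)
    }

    public String sayHello(String keywords) {
        String upKeywords = keywords.toUpperCase();
        System.out.println(upKeywords);
        return upKeywords;
    }
}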
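The article uses IRemoteObj without ever showing it. Below is a minimal, self-contained sketch that assumes the interface simply extends java.rmi.Remote and declares sayHello with RemoteException, which RMI requires of every remote interface method. Constructing RemoteObjImpl exports it on an anonymous port, and RemoteObject.toStub (inherited by UnicastRemoteObject) returns the stub created during that export; because no RemoteObjImpl_Stub class exists, that stub should be a java.lang.reflect.Proxy instance. ExportDemo and stubIsDynamicProxy are hypothetical names used only for this illustration.

```java
import java.lang.reflect.Proxy;
import java.rmi.Remote;
import java.rmi.RemoteException;
import java.rmi.server.UnicastRemoteObject;

// Assumed shape of the interface the article uses but does not show.
interface IRemoteObj extends Remote {
    // Remote interface methods must declare RemoteException.
    String sayHello(String keywords) throws RemoteException;
}

class RemoteObjImpl extends UnicastRemoteObject implements IRemoteObj {
    RemoteObjImpl() throws RemoteException {
        super(); // UnicastRemoteObject() exports this object on port 0
    }

    @Override
    public String sayHello(String keywords) {
        String upKeywords = keywords.toUpperCase();
        System.out.println(upKeywords);
        return upKeywords;
    }
}

// Hypothetical helper: export, inspect the stub, then unexport again.
class ExportDemo {
    static boolean stubIsDynamicProxy() {
        RemoteObjImpl impl;
        try {
            impl = new RemoteObjImpl(); // the export happens inside this constructor
        } catch (RemoteException e) {
            throw new IllegalStateException("export failed", e);
        }
        try {
            Remote stub = UnicastRemoteObject.toStub(impl);
            return Proxy.isProxyClass(stub.getClass()) && stub instanceof IRemoteObj;
        } catch (RemoteException e) {
            throw new IllegalStateException(e);
        } finally {
            try {
                UnicastRemoteObject.unexportObject(impl, true);
            } catch (RemoteException ignored) {
                // already unexported; nothing left to release
            }
        }
    }

    public static void main(String[] args) {
        System.out.println("stub is a dynamic proxy: " + stubIsDynamicProxy());
    }
}
```

Unexporting at the end is deliberate: as far as I know, the RMI runtime keeps the JVM alive while any non-permanent object remains exported.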

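Let's investigate how it publishes the service on the network. In RMIServer, set a breakpoint on IRemoteObj remoteObj = new RemoteObjImpl();, start debugging, press F7 to step in, and then Shift+F7 to step in again.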

protected UnicastRemoteObject() throws RemoteException
{
    this(0);
}

RemoteObjImpl extends UnicastRemoteObject, so the parent class's constructor runs first. The parent constructor passes 0 as the port, which means a random (anonymous) port. Press F7 to keep going.

protected UnicastRemoteObject(int port) throws RemoteException
{
    this.port = port;
    exportObject((Remote) this, port);
}

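The parent class then assigns 0 to port. If a remote service is given port 0, it is published on a random port. Let's keep going: first press F8 to reach exportObject(), then press F7 to step into it.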

public static Remote exportObject(Remote obj, int port)
    throws RemoteException
{
    return exportObject(obj, new UnicastServerRef(port));
}

exportObject() is a static method whose main job is to publish the remote service on the network.

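In this static method, the first argument is the obj object and the second is new UnicastServerRef(port), which is what handles network requests. Following it further takes us to the UnicastServerRef constructor: press F7, then click UnicastServerRef to step into it.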

public UnicastServerRef(int port) {
    super(new LiveRef(port));
}

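Inside the UnicastServerRef constructor, we see that it creates a new LiveRef(port). This is very important: LiveRef is essentially the class that holds the network reference. Step into this(...) to take a look.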

public LiveRef(ObjID objID, int port) {
    this(objID, TCPEndpoint.getLocalEndpoint(port), true);
}

The first argument is the ID and the third is true, so let's focus on the second argument.

TCPEndpoint represents a network endpoint. Its constructor takes an IP (host) and a port, which is all the information needed to make network connections.

public TCPEndpoint(String host, int port) {
    this(host, port, null, null);
}

Now step into the LiveRef constructor:

public LiveRef(ObjID objID, Endpoint endpoint, boolean isLocal) {
    ep = endpoint;
    id = objID;
    this.isLocal = isLocal;
}

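Looking at the assignments, the host and port are stored in the endpoint, and the endpoint is in turn wrapped inside the LiveRef. So just remember that the data lives in the LiveRef, and that the same single LiveRef is used from start to finish.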

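Back where we were, press F7 again to enter super(...) and look at the parent class UnicastRef. This confirms that the whole process of creating the remote service uses only one LiveRef.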

 public UnicastRef(LiveRef liveRef) {
    ref = liveRef;
}
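The point that a single LiveRef is shared can be modeled with plain classes. Everything below is hypothetical (EndpointModel, LiveRefModel, ClientRefModel and ServerRefModel are not JDK types); it only mirrors the containment just described: host and port live in the endpoint, the endpoint lives in the LiveRef, the server ref holds the LiveRef, and the client ref later handed to the stub wraps the very same instance. As far as I can tell from the JDK source, UnicastServerRef.getClientRef() does exactly this with new UnicastRef(ref).

```java
// Hypothetical model classes: they mirror, but are not, sun.rmi.transport.LiveRef etc.
final class EndpointModel {
    final String host;
    final int port;

    EndpointModel(String host, int port) {
        this.host = host;
        this.port = port;
    }
}

final class LiveRefModel {
    final long objId;             // stands in for ObjID
    final EndpointModel endpoint; // host and port are stored here

    LiveRefModel(long objId, EndpointModel endpoint) {
        this.objId = objId;
        this.endpoint = endpoint;
    }
}

// Plays the role of UnicastRef: it only wraps a LiveRef.
class ClientRefModel {
    final LiveRefModel ref;

    ClientRefModel(LiveRefModel ref) {
        this.ref = ref;
    }
}

// Plays the role of UnicastServerRef, which extends UnicastRef.
class ServerRefModel extends ClientRefModel {
    ServerRefModel(LiveRefModel ref) {
        super(ref);
    }

    // A new wrapper each call, but always around the same LiveRef.
    ClientRefModel getClientRef() {
        return new ClientRefModel(ref);
    }
}

class LiveRefDemo {
    public static void main(String[] args) {
        ServerRefModel serverRef =
                new ServerRefModel(new LiveRefModel(1L, new EndpointModel("127.0.0.1", 0)));
        ClientRefModel clientRef = serverRef.getClientRef();
        System.out.println("same LiveRef: " + (clientRef.ref == serverRef.ref)); // same LiveRef: true
    }
}
```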

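Keep pressing F7 until you reach exportObject(Remote impl, Object data, boolean permanent) in UnicastServerRef. Everything that follows revolves around exportObject(); the code basically keeps calling it, so this stretch isn't very important. Keep pressing F7 until the Stub appears: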

public Remote exportObject(Remote impl, Object data,
                           boolean permanent)
    throws RemoteException
{
    Class<?> implClass = impl.getClass();
    Remote stub;

    try {
        stub = Util.createProxy(implClass, getClientRef(), forceStubUse);
    } catch (IllegalArgumentException e) {
        throw new ExportException(
            "remote object implements illegal remote interface", e);
    }
    if (stub instanceof RemoteStub) {
        setSkeleton(impl);
    }

    Target target =
        new Target(impl, this, stub, ref.getObjID(), permanent);
    ref.exportObject(target);
    hashToMethod_Map = hashToMethod_Maps.get(implClass);
    return stub;
}

RMI first creates a Stub on the Service side, i.e. the server, then passes the Stub to the RMI Registry, and finally the RMI Client fetches the Stub.

Let's step in and see how it is created:

public static Remote createProxy(Class<?> implClass,
                                 RemoteRef clientRef,
                                 boolean forceStubUse)
    throws StubNotFoundException
{
    Class<?> remoteClass;

    try {
        remoteClass = getRemoteClass(implClass);
    } catch (ClassNotFoundException ex ) {
        throw new StubNotFoundException(
            "object does not implement a remote interface: " +
            implClass.getName());
    }

    if (forceStubUse ||
        !(ignoreStubClasses || !stubClassExists(remoteClass)))
    {
        return createStub(remoteClass, clientRef);
    }

    final ClassLoader loader = implClass.getClassLoader();
    final Class<?>[] interfaces = getRemoteInterfaces(implClass);
    final InvocationHandler handler =
        new RemoteObjectInvocationHandler(clientRef);

    /* REMIND: private remote interfaces? */

    try {
        return AccessController.doPrivileged(new PrivilegedAction<Remote>() {
            public Remote run() {
                return (Remote) Proxy.newProxyInstance(loader,
                                                       interfaces,
                                                       handler);
            }});
    } catch (IllegalArgumentException e) {
        throw new StubNotFoundException("unable to create proxy", e);
    }
}

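Ignore this condition for now; we'll run into it later and explain it then. Going further, we clearly reach the spot where the dynamic proxy is created (Proxy.newProxyInstance defines and loads a proxy class):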

return AccessController.doPrivileged(new PrivilegedAction<Remote>() {
    public Remote run() {
        return (Remote) Proxy.newProxyInstance(loader,
                                               interfaces,
                                               handler);
    }});

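The first argument is the AppClassLoader, the second is the remote interface(s), and the third is the invocation handler. The invocation handler holds only a ref, a UnicastRef wrapping the same LiveRef we saw earlier: creating a remote service only ever involves that one LiveRef. At this point the dynamic proxy has been created.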
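To see what Proxy.newProxyInstance produces with those three arguments, here is a small stand-alone sketch. RefHoldingHandler is a hypothetical stand-in for RemoteObjectInvocationHandler: its only state is a ref (a plain string here instead of a UnicastRef), and instead of sending the call over the network it returns a string describing what would have been sent. Greeter and ProxyStubDemo are likewise hypothetical.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

interface Greeter {
    String sayHello(String keywords);
}

// Hypothetical stand-in for RemoteObjectInvocationHandler: its only state is the ref.
class RefHoldingHandler implements InvocationHandler {
    final String ref;

    RefHoldingHandler(String ref) {
        this.ref = ref;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) {
        // A real handler would marshal the call and send it through ref.
        String arg = (args == null || args.length == 0) ? "" : String.valueOf(args[0]);
        return "via " + ref + ": " + method.getName() + "(" + arg + ")";
    }
}

class ProxyStubDemo {
    // The same three arguments as in Util.createProxy: loader, interfaces, handler.
    static Greeter createStub(String ref) {
        ClassLoader loader = Greeter.class.getClassLoader();
        Class<?>[] interfaces = { Greeter.class };
        InvocationHandler handler = new RefHoldingHandler(ref);
        return (Greeter) Proxy.newProxyInstance(loader, interfaces, handler);
    }

    public static void main(String[] args) {
        Greeter stub = createStub("127.0.0.1:1099");
        System.out.println(stub.sayHello("hello")); // via 127.0.0.1:1099: sayHello(hello)
    }
}
```

Every call on the proxy funnels into the single invoke method, which is why the handler only needs the ref to reach the remote side.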

Target target =
        new Target(impl, this, stub, ref.getObjID(), permanent);

Press F8 to reach Target. Target is essentially an all-in-one wrapper: everything needed is put into it.

public Target(Remote impl, Dispatcher disp, Remote stub, ObjID id,
              boolean permanent)
{
    this.weakImpl = new WeakRef(impl, ObjectTable.reapQueue);
    this.disp = disp;
    this.stub = stub;
    this.id = id;
    this.acc = AccessController.getContext();

    
    ClassLoader threadContextLoader =
        Thread.currentThread().getContextClassLoader();
    ClassLoader serverLoader = impl.getClass().getClassLoader();
    if (checkLoaderAncestry(threadContextLoader, serverLoader)) {
        this.ccl = threadContextLoader;
    } else {
        this.ccl = serverLoader;
    }

    this.permanent = permanent;
    if (permanent) {
        pinImpl();
    }
}

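Dispatcher disp is the server-side part (here the UnicastServerRef itself, passed as this), and Remote stub is the data meant for the client side.

Then press F8 to the next statement, ref.exportObject(target);, and step in to see how publishing works. Keep pressing F7 until you reach listen: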

   public void exportObject(Target target) throws RemoteException {
    
    synchronized (this) {
        listen();
        exportCount++;
    }
    boolean ok = false;
    try {
        super.exportObject(target);
        ok = true;
    } finally {
        if (!ok) {
            synchronized (this) {
                decrementExportCount();
            }
        }
    }
}

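From here on, starting with the first statement, listen(), the code really begins dealing with the network; step into it.
It first obtains the TCPEndpoint; keep pressing F8 until server = ep.newServerSocket();, then step in: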

   ServerSocket newServerSocket() throws IOException {
    if (TCPTransport.tcpLog.isLoggable(Log.VERBOSE)) {
        TCPTransport.tcpLog.log(Log.VERBOSE,
            "creating server socket on " + this);
    }

    RMIServerSocketFactory serverFactory = ssf;
    if (serverFactory == null) {
        serverFactory = chooseFactory();
    }
    ServerSocket server = serverFactory.createServerSocket(listenPort);

    // if we listened on an anonymous port, set the default port
    // (for this socket factory)
    if (listenPort == 0)
        setDefaultPort(server.getLocalPort(), csf, ssf);

    return server;
}

It opens a server socket, which is now ready and waiting for clients to connect.

   if (listenPort == 0)
        setDefaultPort(server.getLocalPort(), csf, ssf);

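If the port passed in earlier was 0, the OS assigns a random free port, and setDefaultPort records that port (server.getLocalPort()) as the default for this socket factory.

Then go back to listen() and step to Thread t = AccessController.doPrivileged(new NewThreadAction(new AcceptLoop(server), "TCP Accept-" + port, true));

Step into AcceptLoop(server):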
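The port-0 behavior is plain java.net behavior, so it can be reproduced without RMI. A minimal sketch with the standard ServerSocket API (AnonymousPortDemo and bindAnonymousPort are hypothetical names): binding to port 0 lets the OS pick a free port, and getLocalPort() reports which one, the same value newServerSocket() reads back.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.ServerSocket;

class AnonymousPortDemo {
    // Binding to port 0 asks the OS for any free port, like listenPort == 0 in newServerSocket().
    static int bindAnonymousPort() {
        try (ServerSocket server = new ServerSocket(0)) {
            return server.getLocalPort(); // the port that was actually assigned
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println("assigned port: " + bindAnonymousPort());
    }
}
```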

  private class AcceptLoop implements Runnable {

    private final ServerSocket serverSocket;

   
    private long lastExceptionTime = 0L;
    private int recentExceptionCount;

    AcceptLoop(ServerSocket serverSocket) {
        this.serverSocket = serverSocket;
    }

    public void run() {
        try {
            executeAcceptLoop();
        } finally {
            try {
                
                serverSocket.close();
            } catch (IOException e) {
            }
        }
    }

Then step into executeAcceptLoop():

private void executeAcceptLoop() {
    if (tcpLog.isLoggable(Log.BRIEF)) {
        tcpLog.log(Log.BRIEF, "listening on port " +
                   getEndpoint().getPort());
    }
    // ... (the rest of the accept loop is omitted here)
}

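The rest of this method loops on serverSocket.accept() and hands each incoming connection off for processing. At that point the remote object has been published.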
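The listen/AcceptLoop pattern (a server socket plus a daemon thread named "TCP Accept-<port>" that keeps calling accept()) can be sketched with plain sockets. This is a simplified illustration under stated assumptions: AcceptLoopDemo is hypothetical, each connection carries a single echo request handled inline on the accept thread, whereas, as I understand TCPTransport, the real loop hands each accepted socket to a connection thread pool and speaks JRMP.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;

class AcceptLoopDemo {
    // Simplified accept loop: accept, answer one request, repeat.
    static void acceptLoop(ServerSocket serverSocket) {
        while (!serverSocket.isClosed()) {
            try (Socket socket = serverSocket.accept()) {
                DataInputStream in = new DataInputStream(socket.getInputStream());
                DataOutputStream out = new DataOutputStream(socket.getOutputStream());
                out.writeUTF("echo:" + in.readUTF());
                out.flush();
            } catch (IOException e) {
                // Any I/O error (including accept() failing once the server socket
                // is closed) ends this simplified loop.
                return;
            }
        }
    }

    static String roundTrip(String message) {
        try (ServerSocket server = new ServerSocket(0)) {
            int port = server.getLocalPort();
            // Daemon thread, named like RMI's accept thread.
            Thread t = new Thread(() -> acceptLoop(server), "TCP Accept-" + port);
            t.setDaemon(true);
            t.start();
            try (Socket client = new Socket(InetAddress.getLoopbackAddress(), port)) {
                DataOutputStream out = new DataOutputStream(client.getOutputStream());
                DataInputStream in = new DataInputStream(client.getInputStream());
                out.writeUTF(message);
                out.flush();
                return in.readUTF();
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(roundTrip("ping")); // echo:ping
    }
}
```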

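3. Creating the registry

Creating the registry is independent of creating the remote service, so the order doesn't matter; under the hood, both go through the same export process.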

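import java.rmi.registry.LocateRegistry;
import java.rmi.registry.Registry;

public class RMIServer {
    public static void main(String[] args) throws Exception {
        IRemoteObj remoteObj = new RemoteObjImpl();
        Registry r = LocateRegistry.createRegistry(1099);
        r.bind("remoteObj", remoteObj);
    }
}

Set a breakpoint on the second statement, then step into createRegistry: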

  public static Registry createRegistry(int port) throws RemoteException {
    return new RegistryImpl(port);
}

Press F7 again to reach RegistryImpl:

public RegistryImpl(int port)
    throws RemoteException
{
    if (port == Registry.REGISTRY_PORT && System.getSecurityManager() != null) {
        // grant permission for default port only.
        try {
            AccessController.doPrivileged(new PrivilegedExceptionAction<Void>() {
                public Void run() throws RemoteException {
                    LiveRef lref = new LiveRef(id, port);
                    setup(new UnicastServerRef(lref));
                    return null;
                }
            }, null, new SocketPermission("localhost:"+port, "listen,accept"));
        } catch (PrivilegedActionException pae) {
            throw (RemoteException)pae.getException();
        }
    } else {
        LiveRef lref = new LiveRef(id, port);
        setup(new UnicastServerRef(lref));
    }
}

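It first checks whether port is the registry's default port (1099) and whether a SecurityManager is installed, i.e. a series of security checks. No SecurityManager is installed here, so that branch is skipped and execution goes to: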

        LiveRef lref = new LiveRef(id, port);
        setup(new UnicastServerRef(lref));

LiveRef works the same as above, so we'll skip it and turn our attention to setup:

 private void setup(UnicastServerRef uref)
    throws RemoteException
{
    ref = uref;
    uref.exportObject(this, null, true);
}

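Stepping in, it's the same as before: first an assignment, then a call to exportObject(). The difference is the third argument, permanent: it was false for the remote object earlier and is true here. This means the registry object we are creating is a permanent object, whereas the earlier remote object was a temporary one. Press F7 into exportObject; just as when publishing the remote object, we reach the stage of creating the Stub.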

 public Remote exportObject(Remote impl, Object data,
                           boolean permanent)
    throws RemoteException
{
    Class<?> implClass = impl.getClass();
    Remote stub;

    try {
        stub = Util.createProxy(implClass, getClientRef(), forceStubUse);
    } catch (IllegalArgumentException e) {
        throw new ExportException(
            "remote object implements illegal remote interface", e);
    }
    if (stub instanceof RemoteStub) {
        setSkeleton(impl);
    }

    Target target =
        new Target(impl, this, stub, ref.getObjID(), permanent);
    ref.exportObject(target);
    hashToMethod_Map = hashToMethod_Maps.get(implClass);
    return stub;
}

Let's look at stub = Util.createProxy(implClass, getClientRef(), forceStubUse);:

public static Remote createProxy(Class<?> implClass,
                                 RemoteRef clientRef,
                                 boolean forceStubUse)
    throws StubNotFoundException
{
    Class<?> remoteClass;

    try {
        remoteClass = getRemoteClass(implClass);
    } catch (ClassNotFoundException ex ) {
        throw new StubNotFoundException(
            "object does not implement a remote interface: " +
            implClass.getName());
    }

    if (forceStubUse ||
        !(ignoreStubClasses || !stubClassExists(remoteClass)))
    {
        return createStub(remoteClass, clientRef);
    }

    final ClassLoader loader = implClass.getClassLoader();
    final Class<?>[] interfaces = getRemoteInterfaces(implClass);
    final InvocationHandler handler =
        new RemoteObjectInvocationHandler(clientRef);

    /* REMIND: private remote interfaces? */

    try {
        return AccessController.doPrivileged(new PrivilegedAction<Remote>() {
            public Remote run() {
                return (Remote) Proxy.newProxyInstance(loader,
                                                       interfaces,
                                                       handler);
            }});
    } catch (IllegalArgumentException e) {
        throw new StubNotFoundException("unable to create proxy", e);
    }
}

First there is a check here; step into stubClassExists to see what it decides:

   private static boolean stubClassExists(Class<?> remoteClass) {
    if (!withoutStubs.containsKey(remoteClass)) {
        try {
            Class.forName(remoteClass.getName() + "_Stub",
                          false,
                          remoteClass.getClassLoader());
            return true;

        } catch (ClassNotFoundException cnfe) {
            withoutStubs.put(remoteClass, null);
        }
    }
    return false;
}

Here it checks whether the class RegistryImpl_Stub can be loaded: if RegistryImpl_Stub exists it returns true, otherwise false. RegistryImpl_Stub does exist, so the result is true.
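The "<class name> + _Stub" convention is just a Class.forName lookup, so it can be reproduced outside the JDK. This sketch copies the logic of stubClassExists shown above and applies it to three hypothetical classes: Calculator, which has a Calculator_Stub next to it, and Printer, which does not.

```java
import java.util.Collections;
import java.util.Map;
import java.util.WeakHashMap;

class Calculator {}       // hypothetical remote class
class Calculator_Stub {}  // its "pre-generated" stub, found purely by name
class Printer {}          // hypothetical class without a stub

class StubLookupDemo {
    // Cache of classes already known to have no stub, like Util.withoutStubs.
    private static final Map<Class<?>, Void> withoutStubs =
            Collections.synchronizedMap(new WeakHashMap<Class<?>, Void>());

    static boolean stubClassExists(Class<?> remoteClass) {
        if (!withoutStubs.containsKey(remoteClass)) {
            try {
                // Load (without initializing) "<name>_Stub" from the same class loader.
                Class.forName(remoteClass.getName() + "_Stub", false,
                              remoteClass.getClassLoader());
                return true;
            } catch (ClassNotFoundException cnfe) {
                withoutStubs.put(remoteClass, null);
            }
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(stubClassExists(Calculator.class)); // true
        System.out.println(stubClassExists(Printer.class));    // false
    }
}
```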
Next we step into return createStub(remoteClass, clientRef);:

  private static RemoteStub createStub(Class<?> remoteClass, RemoteRef ref)
    throws StubNotFoundException
{
    String stubname = remoteClass.getName() + "_Stub";
    try {
        Class<?> stubcl =
            Class.forName(stubname, false, remoteClass.getClassLoader());
        Constructor<?> cons = stubcl.getConstructor(stubConsParamTypes);
        return (RemoteStub) cons.newInstance(new Object[] { ref });

    } catch (ClassNotFoundException e) {
        throw new StubNotFoundException(
            "Stub class not found: " + stubname, e);
    } catch (NoSuchMethodException e) {
        throw new StubNotFoundException(
            "Stub class missing constructor: " + stubname, e);
    } catch (InstantiationException e) {
        throw new StubNotFoundException(
            "Can't create instance of stub class: " + stubname, e);
    } catch (IllegalAccessException e) {
        throw new StubNotFoundException(
            "Stub class constructor not public: " + stubname, e);
    } catch (InvocationTargetException e) {
        throw new StubNotFoundException(
            "Exception creating instance of stub class: " + stubname, e);
    } catch (ClassCastException e) {
        throw new StubNotFoundException(
            "Stub class not instance of RemoteStub: " + stubname, e);
    }
}

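This creates the stub, an instance of the pre-generated RegistryImpl_Stub class rather than a dynamic proxy, through reflection, passing ref into its constructor.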
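The reflective steps in createStub (forName, getConstructor, newInstance with the ref) can be sketched in isolation. Everything here is hypothetical: HypotheticalRef stands in for RemoteRef (the JDK code looks the constructor up with stubConsParamTypes, i.e. a RemoteRef parameter), and Account_Stub plays the role of a pre-generated stub such as RegistryImpl_Stub.

```java
import java.lang.reflect.Constructor;

// Hypothetical stand-in for RemoteRef.
class HypotheticalRef {
    final String endpoint;

    HypotheticalRef(String endpoint) {
        this.endpoint = endpoint;
    }
}

class Account {} // hypothetical remote class

// Hypothetical pre-generated stub with a public (ref) constructor.
class Account_Stub {
    final HypotheticalRef ref;

    public Account_Stub(HypotheticalRef ref) {
        this.ref = ref;
    }
}

class CreateStubDemo {
    static Object createStub(Class<?> remoteClass, HypotheticalRef ref) {
        String stubname = remoteClass.getName() + "_Stub";
        try {
            Class<?> stubcl = Class.forName(stubname, false, remoteClass.getClassLoader());
            Constructor<?> cons = stubcl.getConstructor(HypotheticalRef.class);
            return cons.newInstance(ref); // the ref ends up inside the stub
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("cannot create stub: " + stubname, e);
        }
    }

    public static void main(String[] args) {
        Account_Stub stub =
                (Account_Stub) createStub(Account.class, new HypotheticalRef("127.0.0.1:1099"));
        System.out.println(stub.ref.endpoint); // 127.0.0.1:1099
    }
}
```

Catching ReflectiveOperationException collapses the five separate catch clauses of the JDK version into one, since all of those exceptions share that supertype.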

    if (stub instanceof RemoteStub) {
        setSkeleton(impl);
    }

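Continuing, if the stub is a RemoteStub, i.e. a stub class pre-defined on the server side as in this case, setSkeleton() is called. Step into it: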

     public void setSkeleton(Remote impl) throws RemoteException {
    if (!withoutSkeletons.containsKey(impl.getClass())) {
        try {
            skel = Util.createSkeleton(impl);
        } catch (SkeletonNotFoundException e) {
            
            withoutSkeletons.put(impl.getClass(), null);
        }
    }
}

Here there is a createSkeleton() method; as the name says, it creates the Skeleton, which in our diagram acts as the server-side proxy.

    static Skeleton createSkeleton(Remote object)
    throws SkeletonNotFoundException
{
    Class<?> cl;
    try {
        cl = getRemoteClass(object.getClass());
    } catch (ClassNotFoundException ex ) {
        throw new SkeletonNotFoundException(
            "object does not implement a remote interface: " +
            object.getClass().getName());
    }

    // now try to load the skeleton based ont he name of the class
    String skelname = cl.getName() + "_Skel";
    try {
        Class<?> skelcl = Class.forName(skelname, false, cl.getClassLoader());

        return (Skeleton)skelcl.newInstance();
    } catch (ClassNotFoundException ex) {
        throw new SkeletonNotFoundException("Skeleton class not found: " +
                                            skelname, ex);
    } catch (InstantiationException ex) {
        throw new SkeletonNotFoundException("Can't create skeleton: " +
                                            skelname, ex);
    } catch (IllegalAccessException ex) {
        throw new SkeletonNotFoundException("No public constructor: " +
                                            skelname, ex);
    } catch (ClassCastException ex) {
        throw new SkeletonNotFoundException(
            "Skeleton not of correct class: " + skelname, ex);
    }
}

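The Skeleton is loaded with Class.forName() and then instantiated. Moving on, we reach Target again; it plays the same role as before, holding the wrapped data, so we'll skip quickly past this part.

Continuing, we reach ref.exportObject(target);:

public void exportObject(Target target) throws RemoteException {
    ep.exportObject(target);
}

Press F7 again: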

 public void exportObject(Target target) throws RemoteException {
    transport.exportObject(target);
}

Press F7 again:

    public void exportObject(Target target) throws RemoteException {
    
    synchronized (this) {
        listen();
        exportCount++;
    }
    boolean ok = false;
    try {
        super.exportObject(target);
        ok = true;
    } finally {
        if (!ok) {
            synchronized (this) {
                decrementExportCount();
            }
        }
    }
}

We already know listen() well, so let's look at super.exportObject(target); below it:

public void exportObject(Target target) throws RemoteException {
    target.setExportedTransport(this);
    ObjectTable.putTarget(target);
}

putTarget() puts the wrapped data, the Target, into the ObjectTable.

4. Binding

Binding is the last step: the bind operation.
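Conceptually, ObjectTable is a lookup table from object IDs to Targets, which is how an incoming call finds the object it is meant for. The sketch below is a hypothetical reduction (TargetModel and ObjectTableModel are not JDK classes); as I understand it, the real ObjectTable keys on an ObjectEndpoint (object ID plus transport) and keeps a second table keyed by the implementation. It does use the real java.rmi.server.ObjID, including the well-known ObjID.REGISTRY_ID that the client later uses to reach the registry.

```java
import java.rmi.server.ObjID;
import java.util.HashMap;
import java.util.Map;

// Hypothetical, much-reduced Target: just enough to find the object again.
final class TargetModel {
    final ObjID id;
    final Object impl;
    final boolean permanent;

    TargetModel(ObjID id, Object impl, boolean permanent) {
        this.id = id;
        this.impl = impl;
        this.permanent = permanent;
    }
}

// Hypothetical ObjectTable: exported targets looked up by object ID.
final class ObjectTableModel {
    private static final Map<ObjID, TargetModel> objTable = new HashMap<>();

    static synchronized void putTarget(TargetModel target) {
        objTable.put(target.id, target);
    }

    static synchronized TargetModel getTarget(ObjID id) {
        return objTable.get(id);
    }
}

class ObjectTableDemo {
    public static void main(String[] args) {
        // The registry is exported permanently under the well-known registry ID.
        ObjectTableModel.putTarget(new TargetModel(new ObjID(ObjID.REGISTRY_ID), "registry impl", true));
        // A separately constructed well-known ID compares equal, so the lookup succeeds.
        TargetModel found = ObjectTableModel.getTarget(new ObjID(ObjID.REGISTRY_ID));
        System.out.println(found.impl); // registry impl
    }
}
```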

import java.rmi.registry.LocateRegistry;
import java.rmi.registry.Registry;

public class RMIServer {
    public static void main(String[] args) throws Exception {
        IRemoteObj remoteObj = new RemoteObjImpl();
        Registry r = LocateRegistry.createRegistry(1099);
        r.bind("remoteObj", remoteObj);
    }
}

Set a breakpoint on bind and step in:

public void bind(String name, Remote obj)
    throws RemoteException, AlreadyBoundException, AccessException
{
    checkAccess("Registry.bind");
    synchronized (bindings) {
        Remote curr = bindings.get(name);
        if (curr != null)
            throw new AlreadyBoundException(name);
        bindings.put(name, obj);
    }
}

checkAccess("Registry.bind"); checks whether the bind request comes from the local host; since we bind locally here, it always passes. Let's continue:

Remote curr = bindings.get(name);
if (curr != null)
    throw new AlreadyBoundException(name);

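It checks whether bindings already holds an entry for this name; bindings is just a Hashtable. If an entry already exists, an AlreadyBoundException is thrown.

Moving on is bindings.put(name, obj);, which puts the name and the remote object into bindings, and that's the end of it.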
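The bind/lookup bookkeeping is small enough to reproduce locally. MiniRegistry below is a hypothetical, network-free reduction of RegistryImpl's bind and lookup that uses the real java.rmi exception types: the map goes from name to remote object, and the host and port only travel later, inside that object's stub.

```java
import java.rmi.AlreadyBoundException;
import java.rmi.NotBoundException;
import java.rmi.Remote;
import java.util.Hashtable;

// Hypothetical, local-only reduction of RegistryImpl's bind/lookup.
class MiniRegistry {
    private final Hashtable<String, Remote> bindings = new Hashtable<>();

    void bind(String name, Remote obj) throws AlreadyBoundException {
        synchronized (bindings) {
            Remote curr = bindings.get(name);
            if (curr != null) {
                throw new AlreadyBoundException(name);
            }
            bindings.put(name, obj); // name -> remote object
        }
    }

    Remote lookup(String name) throws NotBoundException {
        synchronized (bindings) {
            Remote obj = bindings.get(name);
            if (obj == null) {
                throw new NotBoundException(name);
            }
            return obj;
        }
    }
}

class MiniRegistryDemo {
    // Binding the same name twice must fail the second time.
    static boolean secondBindFails() {
        MiniRegistry registry = new MiniRegistry();
        Remote obj = new Remote() {}; // java.rmi.Remote is only a marker interface
        try {
            registry.bind("remoteObj", obj);
        } catch (AlreadyBoundException e) {
            return false; // the first bind should have succeeded
        }
        try {
            registry.bind("remoteObj", obj);
            return false;
        } catch (AlreadyBoundException expected) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println("second bind rejected: " + secondBindFails());
    }
}
```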

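5. Client request: the client calls the registry

This part is where the vulnerabilities are, for a simple reason: it contains several unsafe deserialization calls.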

import java.rmi.registry.LocateRegistry;
import java.rmi.registry.Registry;

public class RMIClient {
    public static void main(String[] args) throws Exception {
        Registry registry = LocateRegistry.getRegistry("127.0.0.1", 1099);
        IRemoteObj remoteObj = (IRemoteObj) registry.lookup("remoteObj");
        remoteObj.sayHello("hello");
    }
}

Set a breakpoint on the first statement and step into getRegistry():

public static Registry getRegistry(String host, int port)
    throws RemoteException
{
    return getRegistry(host, port, null);
}  

Step in again:

 public static Registry getRegistry(String host, int port,
                                   RMIClientSocketFactory csf)
    throws RemoteException
{
    Registry registry = null;

    if (port <= 0)
        port = Registry.REGISTRY_PORT;

    if (host == null || host.length() == 0) {
        // If host is blank (as returned by "file:" URL in 1.0.2 used in
        // java.rmi.Naming), try to convert to real local host name so
        // that the RegistryImpl's checkAccess will not fail.
        try {
            host = java.net.InetAddress.getLocalHost().getHostAddress();
        } catch (Exception e) {
            // If that failed, at least try "" (localhost) anyway...
            host = "";
        }
    }
    LiveRef liveRef =
        new LiveRef(new ObjID(ObjID.REGISTRY_ID),
                    new TCPEndpoint(host, port, csf, null),
                    false);
    RemoteRef ref =
        (csf == null) ? new UnicastRef(liveRef) : new UnicastRef2(liveRef);

    return (Registry) Util.createProxy(RegistryImpl.class, ref, false);
}

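port is 1099 and host is 127.0.0.1. Just like before, a new Ref is created with everything wrapped inside it, which gives us the registry's Stub. No network connection is made at this point; the stub is built locally from the host and port.

Next comes the statement that looks up the remote object, IRemoteObj remoteObj = (IRemoteObj) registry.lookup("remoteObj");. Breakpoints can't be set here while debugging, because the corresponding compiled class file (RegistryImpl_Stub) is a Java 1.1-version class, so we'll just read the code.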
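That getRegistry() builds the stub locally can be checked directly: it returns even when nothing listens on the port, and the first real traffic only happens on a remote call such as lookup. A sketch under two assumptions: GetRegistryDemo is a hypothetical name, and a port freed right after binding it to 0 stays unused for the moment the demo runs.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.ServerSocket;
import java.rmi.NotBoundException;
import java.rmi.RemoteException;
import java.rmi.registry.LocateRegistry;
import java.rmi.registry.Registry;

class GetRegistryDemo {
    // Grab a free port and release it, so (very likely) nothing listens there.
    static int unusedPort() {
        try (ServerSocket s = new ServerSocket(0)) {
            return s.getLocalPort();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // getRegistry only builds a local stub; it does not contact the host.
    static Registry localStub(int port) {
        try {
            return LocateRegistry.getRegistry("127.0.0.1", port);
        } catch (RemoteException e) {
            throw new IllegalStateException(e);
        }
    }

    // The first network traffic happens on a remote call such as lookup.
    static boolean lookupFailsRemotely(Registry registry) {
        try {
            registry.lookup("remoteObj");
            return false;
        } catch (RemoteException e) {
            return true; // e.g. java.rmi.ConnectException: nothing is listening
        } catch (NotBoundException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        Registry registry = localStub(unusedPort());
        System.out.println("stub class: " + registry.getClass().getName());
        System.out.println("lookup failed over the network: " + lookupFailsRemotely(registry));
    }
}
```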

   public Remote lookup(String var1) throws AccessException, NotBoundException, RemoteException {
    try {
        RemoteCall var2 = super.ref.newCall(this, operations, 2, 4905912898345647071L);

        try {
            ObjectOutput var3 = var2.getOutputStream();
            var3.writeObject(var1);
        } catch (IOException var18) {
            throw new MarshalException("error marshalling arguments", var18);
        }

        super.ref.invoke(var2);

        Remote var23;
        try {
            ObjectInput var6 = var2.getInputStream();
            var23 = (Remote)var6.readObject();
        } catch (IOException var15) {
            throw new UnmarshalException("error unmarshalling return", var15);
        } catch (ClassNotFoundException var16) {
            throw new UnmarshalException("error unmarshalling return", var16);
        } finally {
            super.ref.done(var2);
        }

        return var23;
    } catch (RuntimeException var19) {
        throw var19;
    } catch (RemoteException var20) {
        throw var20;
    } catch (NotBoundException var21) {
        throw var21;
    } catch (Exception var22) {
        throw new UnexpectedException("undeclared checked exception", var22);
    }
}

lookup sends the name in serialized form: var3.writeObject(var1); serializes it here, and the registry deserializes it on its side.
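The name travels as a normal Java serialization stream. In this sketch, plain ObjectOutputStream and ObjectInputStream stand in for RMI's own marshal streams (which, as I understand it, subclass them), and LookupMarshalDemo is a hypothetical name. Note that the receiving side simply calls readObject() on whatever bytes arrive, which is why the registry side is also an attack surface.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;

class LookupMarshalDemo {
    // Client side of lookup: the name goes through writeObject.
    static byte[] marshal(String name) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(name);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray(); // closing the stream flushed everything
    }

    // Registry side: the same bytes go through readObject.
    static Object unmarshal(byte[] data) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        byte[] wire = marshal("remoteObj");
        System.out.printf("stream magic: %02X%02X%n", wire[0], wire[1]); // stream magic: ACED
        System.out.println(unmarshal(wire)); // remoteObj
    }
}
```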
Then we reach super.ref.invoke(var2);, which is declared as void invoke(RemoteCall call) throws Exception; in the RemoteRef interface. We go to its implementation in UnicastRef:

  public void invoke(RemoteCall call) throws Exception {
    try {
        clientRefLog.log(Log.VERBOSE, "execute call");

        call.executeCall();

    } catch (RemoteException e) {
        clientRefLog.log(Log.BRIEF, "exception: ", e);
        free(call, false);
        throw e;

    } catch (Error e) {
        clientRefLog.log(Log.BRIEF, "error: ", e);
        free(call, false);
        throw e;

    } catch (RuntimeException e) {
        clientRefLog.log(Log.BRIEF, "exception: ", e);
        free(call, false);
        throw e;

    } catch (Exception e) {
        clientRefLog.log(Log.BRIEF, "exception: ", e);
        free(call, true);
        throw e;
    }
}

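invoke() calls call.executeCall(), the method that actually handles the network request; every network request the client makes goes through it. We'll cover it in more detail later; for now, let's step in.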

  public void executeCall() throws Exception {
    byte returnType;

    // read result header
    DGCAckHandler ackHandler = null;
    try {
        if (out != null) {
            ackHandler = out.getDGCAckHandler();
        }
        releaseOutputStream();
        DataInputStream rd = new DataInputStream(conn.getInputStream());
        byte op = rd.readByte();
        if (op != TransportConstants.Return) {
            if (Transport.transportLog.isLoggable(Log.BRIEF)) {
                Transport.transportLog.log(Log.BRIEF,
                    "transport return code invalid: " + op);
            }
            throw new UnmarshalException("Transport return code invalid");
        }
        getInputStream();
        returnType = in.readByte();
        in.readID();        // id for DGC acknowledgement
    } catch (UnmarshalException e) {
        throw e;
    } catch (IOException e) {
        throw new UnmarshalException("Error unmarshaling return header",
                                     e);
    } finally {
        if (ackHandler != null) {
            ackHandler.release();
        }
    }

    // read return value
    switch (returnType) {
    case TransportConstants.NormalReturn:
        break;

    case TransportConstants.ExceptionalReturn:
        Object ex;
        try {
            ex = in.readObject();
        } catch (Exception e) {
            throw new UnmarshalException("Error unmarshaling return", e);
        }

        // An exception should have been received,
        // if so throw it, else flag error
        if (ex instanceof Exception) {
            exceptionReceivedFromServer((Exception) ex);
        } else {
            throw new UnmarshalException("Return type not Exception");
        }
        // Exception is thrown before fallthrough can occur
    default:
        if (Transport.transportLog.isLoggable(Log.BRIEF)) {
            Transport.transportLog.log(Log.BRIEF,
                "return code invalid: " + returnType);
        }
        throw new UnmarshalException("Return code invalid");
    }
}

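The attack point in executeCall(), called from invoke, is the line ex = in.readObject();. When the server side reports an error, the switch on returnType takes the ExceptionalReturn branch: to give the client the complete exception, the exception object is sent through the stream and deserialized here. So if the registry replies with a malicious stream, it can make the client deserialize a malicious object. The impact is very broad, because every network request goes through invoke.

Continuing down lookup:

RemoteCall var2 = super.ref.newCall(this, operations, 2, 4905912898345647071L);
ObjectInput var6 = var2.getInputStream();
var23 = (Remote)var6.readObject();

It first creates a new remote call, then obtains its input stream and deserializes the result from it. So by planting a malicious object in the registry, we can compromise the client.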
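The header check at the top of executeCall() can be modeled on a byte array. This is a simplified sketch: ReturnHeaderDemo is hypothetical, the constant values are as I recall them from sun.rmi.transport.TransportConstants (verify them against your JDK), and the real code reads returnType through an ObjectInputStream and then a UID, which this sketch skips. It shows where the ExceptionalReturn branch, and thus the readObject() call, is selected purely by bytes the peer sends.

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

class ReturnHeaderDemo {
    // Assumed values of TransportConstants.Return / NormalReturn / ExceptionalReturn.
    static final byte RETURN = 0x51;
    static final byte NORMAL_RETURN = 0x01;
    static final byte EXCEPTIONAL_RETURN = 0x02;

    static String classify(byte[] header) {
        try (DataInputStream rd = new DataInputStream(new ByteArrayInputStream(header))) {
            byte op = rd.readByte();
            if (op != RETURN) {
                throw new IllegalArgumentException("Transport return code invalid");
            }
            byte returnType = rd.readByte();
            switch (returnType) {
                case NORMAL_RETURN:
                    return "normal";      // the caller reads the return value next
                case EXCEPTIONAL_RETURN:
                    return "exceptional"; // the real client calls in.readObject() here
                default:
                    throw new IllegalArgumentException("Return code invalid");
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(classify(new byte[] { 0x51, 0x01 })); // normal
        System.out.println(classify(new byte[] { 0x51, 0x02 })); // exceptional
    }
}
```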
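Why is a bare readObject() dangerous? Deserialization runs the readObject method of whatever class the stream names, before the caller ever gets to cast the result, as in (Remote) var6.readObject(). A harmless sketch of that mechanism: Noisy and SideEffectDemo are hypothetical, and a real exploit would instead name gadget-chain classes already present on the victim's classpath.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical class whose readObject has a visible side effect.
class Noisy implements Serializable {
    private static final long serialVersionUID = 1L;
    static volatile boolean readObjectRan = false;

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        readObjectRan = true; // code chosen by the stream's contents runs here
    }
}

class SideEffectDemo {
    static boolean readObjectRunsClassCode() {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(new Noisy());
            }
            Noisy.readObjectRan = false;
            try (ObjectInputStream in =
                         new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                in.readObject(); // the caller only asked for "an Object"
            }
            return Noisy.readObjectRan;
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println("readObject side effect ran: " + readObjectRunsClassCode()); // true
    }
}
```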

6. Client request: the client calls the server

Main walkthrough

import java.rmi.registry.LocateRegistry;
import java.rmi.registry.Registry;

public class RMIClient {
    public static void main(String[] args) throws Exception {
        Registry registry = LocateRegistry.getRegistry("127.0.0.1", 1099);
        IRemoteObj remoteObj = (IRemoteObj) registry.lookup("remoteObj");
        remoteObj.sayHello("hello");
    }
}

We step into the third statement; this requires Force Step Into.

 public Object invoke(Object proxy, Method method, Object[] args)
    throws Throwable
{
    if (! Proxy.isProxyClass(proxy.getClass())) {
        throw new IllegalArgumentException("not a proxy");
    }

    if (Proxy.getInvocationHandler(proxy) != this) {
        throw new IllegalArgumentException("handler mismatch");
    }

    if (method.getDeclaringClass() == Object.class) {
        return invokeObjectMethod(proxy, method, args);
    } else if ("finalize".equals(method.getName()) && method.getParameterCount() == 0 &&
        !allowFinalizeInvocation) {
        return null; // ignore
    } else {
        return invokeRemoteMethod(proxy, method, args);
    }
}

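The preceding checks only throw exceptions for invalid calls, and methods declared by Object (such as toString) are handled locally. Press F8 all the way to the last statement, return invokeRemoteMethod(proxy, method, args);, then step in.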

private Object invokeRemoteMethod(Object proxy,
                                  Method method,
                                  Object[] args)
    throws Exception
{
    try {
        if (!(proxy instanceof Remote)) {
            throw new IllegalArgumentException(
                "proxy not Remote instance");
        }
        return ref.invoke((Remote) proxy, method, args,
                          getMethodHash(method));
    } catch (Exception e) {
        if (!(e instanceof RuntimeException)) {
            Class<?> cl = proxy.getClass();
            try {
                method = cl.getMethod(method.getName(),
                                      method.getParameterTypes());
            } catch (NoSuchMethodException nsme) {
                throw (IllegalArgumentException)
                    new IllegalArgumentException().initCause(nsme);
            }
            Class<?> thrownType = e.getClass();
            for (Class<?> declaredType : method.getExceptionTypes()) {
                if (declaredType.isAssignableFrom(thrownType)) {
                    throw e;
                }
            }
            e = new UnexpectedException("unexpected exception", e);
        }
        throw e;
    }
}
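The branching in RemoteObjectInvocationHandler.invoke (Object methods answered locally, everything else sent to the remote side) can be imitated with a hypothetical handler. DispatchingHandler, Echo and DispatchDemo are illustrative names; instead of calling ref.invoke, the handler just records which calls would have gone over the network.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

interface Echo {
    String sayHello(String keywords);
}

// Hypothetical handler with the same branching as RemoteObjectInvocationHandler.invoke.
class DispatchingHandler implements InvocationHandler {
    final List<String> remoteCalls = new ArrayList<>();

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) {
        if (method.getDeclaringClass() == Object.class) {
            return invokeObjectMethod(proxy, method, args); // local, no network
        }
        remoteCalls.add(method.getName()); // stands in for invokeRemoteMethod -> ref.invoke
        return "remote result of " + method.getName();
    }

    private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
        switch (method.getName()) {
            case "hashCode":
                return System.identityHashCode(proxy);
            case "equals":
                return proxy == args[0];
            case "toString":
                return "Proxy[Echo]@" + Integer.toHexString(System.identityHashCode(proxy));
            default:
                throw new IllegalArgumentException("unexpected Object method: " + method);
        }
    }
}

class DispatchDemo {
    static List<String> recordRemoteCalls() {
        DispatchingHandler handler = new DispatchingHandler();
        Echo echo = (Echo) Proxy.newProxyInstance(
                Echo.class.getClassLoader(), new Class<?>[] { Echo.class }, handler);
        echo.toString();        // answered locally
        echo.hashCode();        // answered locally
        echo.sayHello("hello"); // would go over the network
        return handler.remoteCalls;
    }

    public static void main(String[] args) {
        System.out.println(recordRemoteCalls()); // [sayHello]
    }
}
```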

Step into the ref.invoke call in return ref.invoke((Remote) proxy, method, args, getMethodHash(method)); (this is UnicastRef.invoke):

public Object invoke(Remote obj,
                     Method method,
                     Object[] params,
                     long opnum)
    throws Exception
{
    if (clientRefLog.isLoggable(Log.VERBOSE)) {
        clientRefLog.log(Log.VERBOSE, "method: " + method);
    }

    if (clientCallLog.isLoggable(Log.VERBOSE)) {
        logClientCall(obj, method);
    }

    Connection conn = ref.getChannel().newConnection();
    RemoteCall call = null;
    boolean reuse = true;
    boolean alreadyFreed = false;

    try {
        if (clientRefLog.isLoggable(Log.VERBOSE)) {
            clientRefLog.log(Log.VERBOSE, "opnum = " + opnum);
        }

        call = new StreamRemoteCall(conn, ref.getObjID(), -1, opnum);

        try {
            ObjectOutput out = call.getOutputStream();
            marshalCustomCallData(out);
            Class<?>[] types = method.getParameterTypes();
            for (int i = 0; i < types.length; i++) {
                marshalValue(types[i], params[i], out);
            }
        } catch (IOException e) {
            clientRefLog.log(Log.BRIEF,
                "IOException marshalling arguments: ", e);
            throw new MarshalException("error marshalling arguments", e);
        }

        call.executeCall();

        try {
            Class<?> rtype = method.getReturnType();
            if (rtype == void.class)
                return null;
            ObjectInput in = call.getInputStream();
            Object returnValue = unmarshalValue(rtype, in);
            alreadyFreed = true;
            clientRefLog.log(Log.BRIEF, "free connection (reuse = true)");
            ref.getChannel().free(conn, true);

            return returnValue;

        } catch (IOException e) {
            clientRefLog.log(Log.BRIEF,
                             "IOException unmarshalling return: ", e);
            throw new UnmarshalException("error unmarshalling return", e);
        } catch (ClassNotFoundException e) {
            clientRefLog.log(Log.BRIEF,
                "ClassNotFoundException unmarshalling return: ", e);

            throw new UnmarshalException("error unmarshalling return", e);
        } finally {
            try {
                call.done();
            } catch (IOException e) {
                reuse = false;
            }
        }

    } catch (RuntimeException e) {
        if ((call == null) ||
            (((StreamRemoteCall) call).getServerException() != e))
        {
            reuse = false;
        }
        throw e;

    } catch (RemoteException e) {
        reuse = false;
        throw e;

    } catch (Error e) {
        reuse = false;
        throw e;

    } finally {
        if (!alreadyFreed) {
            if (clientRefLog.isLoggable(Log.BRIEF)) {
                clientRefLog.log(Log.BRIEF, "free connection (reuse = " +
                                       reuse + ")");
            }
            ref.getChannel().free(conn, reuse);
        }
    }
}

Code explanation

This part checks the logging verbosity level and logs information about the method being called.

if (clientRefLog.isLoggable(Log.VERBOSE)) {
    clientRefLog.log(Log.VERBOSE, "method: " + method);
}

Get a new network connection, which will carry the request and the response of the method call.

Connection conn = ref.getChannel().newConnection();

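Create a StreamRemoteCall object, which encapsulates the context of the remote call: the connection, the object ID, and the operation. Here the operation number is -1 and the opnum argument carries the method hash computed by getMethodHash(method).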

call = new StreamRemoteCall(conn, ref.getObjID(), -1, opnum);
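The opnum passed here is the 64-bit method hash from getMethodHash(method). The sketch below is modeled on my reading of sun.rmi.server.Util.computeMethodHash, so treat the algorithm as an assumption to verify: SHA-1 over the writeUTF encoding of the method name plus its JVM descriptor, with the first 8 digest bytes read little-endian. MethodHashSketch and HelloService are hypothetical names.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.lang.reflect.Method;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

interface HelloService {
    String sayHello(String keywords);
    String sayBye(String keywords);
}

class MethodHashSketch {
    // JVM type descriptor, e.g. String -> "Ljava/lang/String;".
    static String descriptor(Class<?> c) {
        if (c == void.class) return "V";
        if (c == int.class) return "I";
        if (c == long.class) return "J";
        if (c == boolean.class) return "Z";
        if (c == byte.class) return "B";
        if (c == char.class) return "C";
        if (c == short.class) return "S";
        if (c == float.class) return "F";
        if (c == double.class) return "D";
        if (c.isArray()) return c.getName().replace('.', '/'); // e.g. "[Ljava/lang/String;"
        return "L" + c.getName().replace('.', '/') + ";";
    }

    static String nameAndDescriptor(Method m) {
        StringBuilder sb = new StringBuilder(m.getName()).append('(');
        for (Class<?> p : m.getParameterTypes()) {
            sb.append(descriptor(p));
        }
        return sb.append(')').append(descriptor(m.getReturnType())).toString();
    }

    // Assumed algorithm: SHA-1 of writeUTF(name + descriptor), first 8 bytes little-endian.
    static long methodHash(Method m) {
        try {
            ByteArrayOutputStream sink = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(sink);
            out.writeUTF(nameAndDescriptor(m));
            out.flush();
            byte[] digest = MessageDigest.getInstance("SHA-1").digest(sink.toByteArray());
            long hash = 0;
            for (int i = 0; i < Math.min(8, digest.length); i++) {
                hash += ((long) (digest[i] & 0xFF)) << (i * 8);
            }
            return hash;
        } catch (IOException | NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }

    static Method helloMethod(String name) {
        try {
            return HelloService.class.getMethod(name, String.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        Method m = helloMethod("sayHello");
        System.out.println(nameAndDescriptor(m)); // sayHello(Ljava/lang/String;)Ljava/lang/String;
        System.out.println(methodHash(m));
    }
}
```

Because the hash depends only on the method's name and signature, client and server agree on it without exchanging any method table.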

Serialize the method arguments and write them to the output stream so they can be sent over the network to the remote server.

ObjectOutput out = call.getOutputStream();
marshalCustomCallData(out);
Class<?>[] types = method.getParameterTypes();
for (int i = 0; i < types.length; i++) {
    marshalValue(types[i], params[i], out);
}

This method performs the actual remote call and waits for the server to return a result.

call.executeCall();

Read and deserialize the return value of the remote call from the input stream.

  ObjectInput in = call.getInputStream();
  Object returnValue = unmarshalValue(rtype, in);

After the remote call completes, the connection is released; the reuse variable determines whether the connection can be reused.

  ref.getChannel().free(conn, reuse);

Main walkthrough (continued)

Continuing, step into marshalValue:

 protected static void marshalValue(Class<?> type, Object value,
                                   ObjectOutput out)
    throws IOException
{
    if (type.isPrimitive()) {
        if (type == int.class) {
            out.writeInt(((Integer) value).intValue());
        } else if (type == boolean.class) {
            out.writeBoolean(((Boolean) value).booleanValue());
        } else if (type == byte.class) {
            out.writeByte(((Byte) value).byteValue());
        } else if (type == char.class) {
            out.writeChar(((Character) value).charValue());
        } else if (type == short.class) {
            out.writeShort(((Short) value).shortValue());
        } else if (type == long.class) {
            out.writeLong(((Long) value).longValue());
        } else if (type == float.class) {
            out.writeFloat(((Float) value).floatValue());
        } else if (type == double.class) {
            out.writeDouble(((Double) value).doubleValue());
        } else {
            throw new Error("Unrecognized primitive type: " + type);
        }
    } else {
        out.writeObject(value);
    }
}

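It simply serializes each argument: primitives with the matching writeXxx call, everything else with writeObject. Next comes call.executeCall();, which, as mentioned earlier, is an attack point; attacks through it are also known as JRMP protocol attacks. Continue to Object returnValue = unmarshalValue(rtype, in); and step in: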

 protected static Object unmarshalValue(Class<?> type, ObjectInput in)
    throws IOException, ClassNotFoundException
{
    if (type.isPrimitive()) {
        if (type == int.class) {
            return Integer.valueOf(in.readInt());
        } else if (type == boolean.class) {
            return Boolean.valueOf(in.readBoolean());
        } else if (type == byte.class) {
            return Byte.valueOf(in.readByte());
        } else if (type == char.class) {
            return Character.valueOf(in.readChar());
        } else if (type == short.class) {
            return Short.valueOf(in.readShort());
        } else if (type == long.class) {
            return Long.valueOf(in.readLong());
        } else if (type == float.class) {
            return Float.valueOf(in.readFloat());
        } else if (type == double.class) {
            return Double.valueOf(in.readDouble());
        } else {
            throw new Error("Unrecognized primitive type: " + type);
        }
    } else {
        return in.readObject();
    }
}

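This is the deserialization step: for any non-primitive return type it calls readObject, so a malicious server can return a crafted serialized object here to attack the client.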
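To make the two branches concrete, here is a minimal, self-contained sketch that mirrors the shape of marshalValue/unmarshalValue over an in-memory stream. The class name MarshalSketch is hypothetical, and only int and boolean are handled among the primitives for brevity. Primitives go through the DataOutput/DataInput methods; everything else goes through writeObject/readObject, which is the branch that deserializes whatever bytes the peer sent.

```java
import java.io.*;

class MarshalSketch {
    // Same shape as marshalValue: primitives via DataOutput, objects via writeObject.
    static void marshal(Class<?> type, Object value, ObjectOutput out) throws IOException {
        if (type == int.class) {
            out.writeInt((Integer) value);
        } else if (type == boolean.class) {
            out.writeBoolean((Boolean) value);
        } else if (type.isPrimitive()) {
            throw new IllegalArgumentException("primitive not handled in this sketch: " + type);
        } else {
            out.writeObject(value);
        }
    }

    // Same shape as unmarshalValue: the non-primitive branch calls readObject,
    // i.e. full Java deserialization of data controlled by the other side.
    static Object unmarshal(Class<?> type, ObjectInput in) throws IOException, ClassNotFoundException {
        if (type == int.class) {
            return in.readInt();
        } else if (type == boolean.class) {
            return in.readBoolean();
        } else if (type.isPrimitive()) {
            throw new IllegalArgumentException("primitive not handled in this sketch: " + type);
        } else {
            return in.readObject();
        }
    }

    // Writes one value and reads it back, wrapping checked exceptions.
    static Object roundTrip(Class<?> type, Object value) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                marshal(type, value, out);
            }
            try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return unmarshal(type, in);
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }
}
```

In an attack, the writeObject side would carry a malicious object instead of a String; the sketch only shows where readObject is reached.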

7. Client attacking the Registry

public class RMIServer {
public static void main(String[] args) throws Exception{
    IRemoteObj remoteObj = new RemoteObjImpl();
    Registry r = LocateRegistry.createRegistry(1099);
    r.bind("remoteObj",remoteObj);
}
}

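Set a breakpoint on the second statement, LocateRegistry.createRegistry(1099), and step in. It starts the same way as the analysis at the beginning: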

public static Registry createRegistry(int port) throws RemoteException {
    return new RegistryImpl(port);
}

Then press F7 to step in again:

public RegistryImpl(int port)
    throws RemoteException
{
    if (port == Registry.REGISTRY_PORT && System.getSecurityManager() != null) {
        // grant permission for default port only.
        try {
            AccessController.doPrivileged(new PrivilegedExceptionAction<Void>() {
                public Void run() throws RemoteException {
                    LiveRef lref = new LiveRef(id, port);
                    setup(new UnicastServerRef(lref));
                    return null;
                }
            }, null, new SocketPermission("localhost:"+port, "listen,accept"));
        } catch (PrivilegedActionException pae) {
            throw (RemoteException)pae.getException();
        }
    } else {
        LiveRef lref = new LiveRef(id, port);
        setup(new UnicastServerRef(lref));
    }
}

Then step into setup:

private void setup(UnicastServerRef uref)
    throws RemoteException
{
    ref = uref;
    uref.exportObject(this, null, true);
}

Then step into exportObject:

public Remote exportObject(Remote impl, Object data,
                          boolean permanent)
    throws RemoteException
{
    Class<?> implClass = impl.getClass();
    Remote stub;

    try {
        stub = Util.createProxy(implClass, getClientRef(), forceStubUse);
    } catch (IllegalArgumentException e) {
        throw new ExportException(
            "remote object implements illegal remote interface", e);
    }
    if (stub instanceof RemoteStub) {
        setSkeleton(impl);
    }

    Target target =
        new Target(impl, this, stub, ref.getObjID(), permanent);
    ref.exportObject(target);
    hashToMethod_Map = hashToMethod_Maps.get(implClass);
    return stub;
}
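Whether setSkeleton is called depends on what Util.createProxy returns. As far as I know, for a class that has a pre-generated _Stub class (RegistryImpl has RegistryImpl_Stub) it returns a RemoteStub, so a skeleton is set; for an ordinary class such as RemoteObjImpl it returns a java.lang.reflect.Proxy whose handler (java.rmi.server.RemoteObjectInvocationHandler) forwards calls through the client ref, and no skeleton is set. The sketch below only illustrates that dynamic-proxy idea with the standard Proxy API; the Hello interface and the recording handler are hypothetical stand-ins that do no networking.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

// Hypothetical interface shaped like IRemoteObj (without RemoteException).
interface Hello {
    String sayHello(String keywords);
}

class ProxyStubSketch {
    // Records "methodName:firstArg" for every call the stub receives.
    static final List<String> sentCalls = new ArrayList<>();

    static Hello createStub() {
        InvocationHandler handler = (proxy, method, args) -> {
            Object first = (args == null) ? null : args[0];
            // A real RMI handler would marshal the method and arguments and send them here.
            sentCalls.add(method.getName() + ":" + first);
            return "stubbed-" + first;
        };
        return (Hello) Proxy.newProxyInstance(
                Hello.class.getClassLoader(), new Class<?>[] {Hello.class}, handler);
    }
}
```

The caller only sees the interface; every method call lands in the handler, which is why a dynamic-proxy stub needs no generated class on either side.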

Next, step into ref.exportObject(target);

public void exportObject(Target target) throws RemoteException {
    ep.exportObject(target);
}

Next, step into ep.exportObject(target);

public void exportObject(Target target) throws RemoteException {
    transport.exportObject(target);
}

Next, step into transport.exportObject(target);

public void exportObject(Target target) throws RemoteException {
    synchronized (this) {
        listen();
        exportCount++;
    }
    boolean ok = false;
    try {
        super.exportObject(target);
        ok = true;
    } finally {
        if (!ok) {
            synchronized (this) {
                decrementExportCount();
            }
        }
    }
}

Then step into listen():

private void listen() throws RemoteException {
    assert Thread.holdsLock(this);
    TCPEndpoint ep = getEndpoint();
    int port = ep.getPort();

    if (server == null) {
        if (tcpLog.isLoggable(Log.BRIEF)) {
            tcpLog.log(Log.BRIEF,
                "(port " + port + ") create server socket");
        }

        try {
            server = ep.newServerSocket();
            Thread t = AccessController.doPrivileged(
                new NewThreadAction(new AcceptLoop(server),
                                    "TCP Accept-" + port, true));
            t.start();
        } catch (java.net.BindException e) {
            throw new ExportException("Port already in use: " + port, e);
        } catch (IOException e) {
            throw new ExportException("Listen failed on port: " + port, e);
        }

    } else {
        // otherwise verify security access to existing server socket
        SecurityManager sm = System.getSecurityManager();
        if (sm != null) {
            sm.checkListen(port);
        }
    }
}
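The bookkeeping in TCPTransport.exportObject and listen() follows a simple pattern: create the server socket and accept thread only on the first export for a transport, count exports, and roll the count back if registering the target fails. The socket-free model below captures just that pattern; TransportModel and its members are hypothetical names, and the boolean flag stands in for the server != null check.

```java
class TransportModel {
    private boolean listening = false; // stands in for "server != null"
    private int serverSocketsCreated = 0;
    private int exportCount = 0;

    // Like listen(): only the first call "creates the server socket".
    private void listen() {
        if (!listening) {
            listening = true;
            serverSocketsCreated++;
        }
    }

    // Like exportObject(Target): listen and count, then undo the count on failure.
    synchronized void exportObject(boolean registrationSucceeds) {
        listen();
        exportCount++;
        boolean ok = false;
        try {
            if (!registrationSucceeds) {
                throw new IllegalStateException("simulated registration failure");
            }
            ok = true;
        } finally {
            if (!ok) {
                exportCount--; // plays the role of decrementExportCount()
            }
        }
    }

    int serverSocketsCreated() { return serverSocketsCreated; }
    int exportCount() { return exportCount; }
}
```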

listen() creates the server socket and starts a new thread with Thread t = AccessController.doPrivileged(new NewThreadAction(new AcceptLoop(server), "TCP Accept-" + port, true));. Step into AcceptLoop:

  private class AcceptLoop implements Runnable {

    private final ServerSocket serverSocket;

    // state for throttling loop on exceptions (local to accept thread)
    private long lastExceptionTime = 0L;
    private int recentExceptionCount;

    AcceptLoop(ServerSocket serverSocket) {
        this.serverSocket = serverSocket;
    }

    public void run() {
        try {
            executeAcceptLoop();
        } finally {
            try {
                serverSocket.close();
            } catch (IOException e) {
            }
        }
    }

Then step into executeAcceptLoop():

private void executeAcceptLoop() {
        if (tcpLog.isLoggable(Log.BRIEF)) {
            tcpLog.log(Log.BRIEF, "listening on port " +
                       getEndpoint().getPort());
        }

        while (true) {
            Socket socket = null;
            try {
                socket = serverSocket.accept();

                
                InetAddress clientAddr = socket.getInetAddress();
                String clientHost = (clientAddr != null
                                     ? clientAddr.getHostAddress()
                                     : "0.0.0.0");

                
                try {
                    connectionThreadPool.execute(
                        new ConnectionHandler(socket, clientHost));
                } catch (RejectedExecutionException e) {
                    closeSocket(socket);
                    tcpLog.log(Log.BRIEF,
                               "rejected connection from " + clientHost);
                }

            } catch (Throwable t) {
                try {
                    
                    if (serverSocket.isClosed()) {
                        break;
                    }

                    try {
                        if (tcpLog.isLoggable(Level.WARNING)) {
                            tcpLog.log(Level.WARNING,
                                       "accept loop for " + serverSocket +
                                       " throws", t);
                        }
                    } catch (Throwable tt) {
                    }
                } finally {
                   
                    if (socket != null) {
                        closeSocket(socket);
                    }
                }

                
                if (!(t instanceof SecurityException)) {
                    try {
                        TCPEndpoint.shedConnectionCaches();
                    } catch (Throwable tt) {
                    }
                }

                
                if (t instanceof Exception ||
                    t instanceof OutOfMemoryError ||
                    t instanceof NoClassDefFoundError)
                {
                    if (!continueAfterAcceptFailure(t)) {
                        return;
                    }
                    // continue loop
                } else if (t instanceof Error) {
                    throw (Error) t;
                } else {
                    throw new UndeclaredThrowableException(t);
                }
            }
        }
    }

Each accepted socket is handed off with connectionThreadPool.execute(new ConnectionHandler(socket, clientHost));. Step into ConnectionHandler and go to its run method:

  public void run() {
        Thread t = Thread.currentThread();
        String name = t.getName();
        try {
            t.setName("RMI TCP Connection(" +
                      connectionCount.incrementAndGet() +
                      ")-" + remoteHost);
            AccessController.doPrivileged((PrivilegedAction<Void>)() -> {
                run0();
                return null;
            }, NOPERMS_ACC);
        } finally {
            t.setName(name);
        }
    }

It calls run0:

   private void run0() {
        TCPEndpoint endpoint = getEndpoint();
        int port = endpoint.getPort();

        threadConnectionHandler.set(this);

        // set socket to disable Nagle's algorithm (always send
        // immediately)
        // TBD: should this be left up to socket factory instead?
        try {
            socket.setTcpNoDelay(true);
        } catch (Exception e) {
            // if we fail to set this, ignore and proceed anyway
        }
        // set socket to timeout after excessive idle time
        try {
            if (connectionReadTimeout > 0)
                socket.setSoTimeout(connectionReadTimeout);
        } catch (Exception e) {
            // too bad, continue anyway
        }

        try {
            InputStream sockIn = socket.getInputStream();
            InputStream bufIn = sockIn.markSupported()
                    ? sockIn
                    : new BufferedInputStream(sockIn);

            // Read magic (or HTTP wrapper)
            bufIn.mark(4);
            DataInputStream in = new DataInputStream(bufIn);
            int magic = in.readInt();

            if (magic == POST) {
                tcpLog.log(Log.BRIEF, "decoding HTTP-wrapped call");

                // It's really a HTTP-wrapped request.  Repackage
                // the socket in a HttpReceiveSocket, reinitialize
                // sockIn and in, and reread magic.
                bufIn.reset();      // unread "POST"

                try {
                    socket = new HttpReceiveSocket(socket, bufIn, null);
                    remoteHost = "0.0.0.0";
                    sockIn = socket.getInputStream();
                    bufIn = new BufferedInputStream(sockIn);
                    in = new DataInputStream(bufIn);
                    magic = in.readInt();

                } catch (IOException e) {
                    throw new RemoteException("Error HTTP-unwrapping call",
                                              e);
                }
            }
            // bufIn's mark will invalidate itself when it overflows
            // so it doesn't have to be turned off

            // read and verify transport header
            short version = in.readShort();
            if (magic != TransportConstants.Magic ||
                version != TransportConstants.Version) {
                // protocol mismatch detected...
                // just close socket: this would recurse if we marshal an
                // exception to the client and the protocol at other end
                // doesn't match.
                closeSocket(socket);
                return;
            }

            OutputStream sockOut = socket.getOutputStream();
            BufferedOutputStream bufOut =
                new BufferedOutputStream(sockOut);
            DataOutputStream out = new DataOutputStream(bufOut);

            int remotePort = socket.getPort();

            if (tcpLog.isLoggable(Log.BRIEF)) {
                tcpLog.log(Log.BRIEF, "accepted socket from [" +
                                 remoteHost + ":" + remotePort + "]");
            }

            TCPEndpoint ep;
            TCPChannel ch;
            TCPConnection conn;

            // send ack (or nack) for protocol
            byte protocol = in.readByte();
            switch (protocol) {
            case TransportConstants.SingleOpProtocol:
                // no ack for protocol

                // create dummy channel for receiving messages
                ep = new TCPEndpoint(remoteHost, socket.getLocalPort(),
                                     endpoint.getClientSocketFactory(),
                                     endpoint.getServerSocketFactory());
                ch = new TCPChannel(TCPTransport.this, ep);
                conn = new TCPConnection(ch, socket, bufIn, bufOut);

                // read input messages
                handleMessages(conn, false);
                break;

            case TransportConstants.StreamProtocol:
                // send ack
                out.writeByte(TransportConstants.ProtocolAck);

                // suggest endpoint (in case client doesn't know host name)
                if (tcpLog.isLoggable(Log.VERBOSE)) {
                    tcpLog.log(Log.VERBOSE, "(port " + port +
                        ") " + "suggesting " + remoteHost + ":" +
                        remotePort);
                }

                out.writeUTF(remoteHost);
                out.writeInt(remotePort);
                out.flush();

                // read and discard (possibly bogus) endpoint
                // REMIND: would be faster to read 2 bytes then skip N+4
                String clientHost = in.readUTF();
                int    clientPort = in.readInt();
                if (tcpLog.isLoggable(Log.VERBOSE)) {
                    tcpLog.log(Log.VERBOSE, "(port " + port +
                        ") client using " + clientHost + ":" + clientPort);
                }

                // create dummy channel for receiving messages
                // (why not use clientHost and clientPort?)
                ep = new TCPEndpoint(remoteHost, socket.getLocalPort(),
                                     endpoint.getClientSocketFactory(),
                                     endpoint.getServerSocketFactory());
                ch = new TCPChannel(TCPTransport.this, ep);
                conn = new TCPConnection(ch, socket, bufIn, bufOut);

                // read input messages
                handleMessages(conn, true);
                break;

            case TransportConstants.MultiplexProtocol:
                if (tcpLog.isLoggable(Log.VERBOSE)) {
                    tcpLog.log(Log.VERBOSE, "(port " + port +
                        ") accepting multiplex protocol");
                }

                // send ack
                out.writeByte(TransportConstants.ProtocolAck);

                // suggest endpoint (in case client doesn't already have one)
                if (tcpLog.isLoggable(Log.VERBOSE)) {
                    tcpLog.log(Log.VERBOSE, "(port " + port +
                        ") suggesting " + remoteHost + ":" + remotePort);
                }

                out.writeUTF(remoteHost);
                out.writeInt(remotePort);
                out.flush();

                // read endpoint client has decided to use
                ep = new TCPEndpoint(in.readUTF(), in.readInt(),
                                     endpoint.getClientSocketFactory(),
                                     endpoint.getServerSocketFactory());
                if (tcpLog.isLoggable(Log.VERBOSE)) {
                    tcpLog.log(Log.VERBOSE, "(port " +
                        port + ") client using " +
                        ep.getHost() + ":" + ep.getPort());
                }

                ConnectionMultiplexer multiplexer;
                synchronized (channelTable) {
                    // create or find channel for this endpoint
                    ch = getChannel(ep);
                    multiplexer =
                        new ConnectionMultiplexer(ch, bufIn, sockOut,
                                                  false);
                    ch.useMultiplexer(multiplexer);
                }
                multiplexer.run();
                break;

            default:
                // protocol not understood, send nack and close socket
                out.writeByte(TransportConstants.ProtocolNack);
                out.flush();
                break;
            }

        } catch (IOException e) {
            // socket in unknown state: destroy socket
            tcpLog.log(Log.BRIEF, "terminated with exception:", e);
        } finally {
            closeSocket(socket);
        }
    }
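The start of run0 reads a fixed transport header: a 4-byte magic, a 2-byte version, then one protocol byte. The sketch below builds such a header in memory and classifies it the way run0 does (without the HTTP "POST" wrapper case). The constant values are my recollection of sun.rmi.transport.TransportConstants (magic 0x4a524d49, the ASCII bytes "JRMI"; version 2; StreamProtocol 0x4b; SingleOpProtocol 0x4c; MultiplexProtocol 0x4d); treat them as assumptions and check your JDK's source.

```java
import java.io.*;

class JrmpHeaderSketch {
    // Assumed TransportConstants values (see the note above).
    static final int MAGIC = 0x4a524d49; // "JRMI"
    static final short VERSION = 2;
    static final byte STREAM_PROTOCOL = 0x4b;
    static final byte SINGLE_OP_PROTOCOL = 0x4c;
    static final byte MULTIPLEX_PROTOCOL = 0x4d;

    // Client side: magic, version, protocol byte.
    static byte[] header(int magic, short version, byte protocol) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            out.writeInt(magic);
            out.writeShort(version);
            out.writeByte(protocol);
            out.flush();
            return bytes.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Server side: same checks and switch as run0, returning a label instead of serving.
    static String classify(byte[] data) {
        try {
            DataInputStream in = new DataInputStream(new ByteArrayInputStream(data));
            int magic = in.readInt();
            short version = in.readShort();
            if (magic != MAGIC || version != VERSION) {
                return "mismatch: close socket";
            }
            byte protocol = in.readByte();
            switch (protocol) {
                case SINGLE_OP_PROTOCOL: return "single-op: handleMessages(conn, false)";
                case STREAM_PROTOCOL:    return "stream: ack, then handleMessages(conn, true)";
                case MULTIPLEX_PROTOCOL: return "multiplex: ConnectionMultiplexer";
                default:                 return "unknown: send nack";
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```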

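run0 handles the incoming connection: it verifies the transport header and then, depending on the protocol byte, calls handleMessages. For the StreamProtocol that an ordinary RMI client uses, this is handleMessages(conn, true); the SingleOpProtocol case calls handleMessages(conn, false). Step into it: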

void handleMessages(Connection conn, boolean persistent) {
    int port = getEndpoint().getPort();

    try {
        DataInputStream in = new DataInputStream(conn.getInputStream());
        do {
            int op = in.read();     // transport op
            if (op == -1) {
                if (tcpLog.isLoggable(Log.BRIEF)) {
                    tcpLog.log(Log.BRIEF, "(port " +
                        port + ") connection closed");
                }
                break;
            }

            if (tcpLog.isLoggable(Log.BRIEF)) {
                tcpLog.log(Log.BRIEF, "(port " + port +
                    ") op = " + op);
            }

            switch (op) {
            case TransportConstants.Call:
                // service incoming RMI call
                RemoteCall call = new StreamRemoteCall(conn);
                if (serviceCall(call) == false)
                    return;
                break;

            case TransportConstants.Ping:
                // send ack for ping
                DataOutputStream out =
                    new DataOutputStream(conn.getOutputStream());
                out.writeByte(TransportConstants.PingAck);
                conn.releaseOutputStream();
                break;

            case TransportConstants.DGCAck:
                DGCAckHandler.received(UID.read(in));
                break;

            default:
                throw new IOException("unknown transport op " + op);
            }
        } while (persistent);

    } catch (IOException e) {
        // exception during processing causes connection to close (below)
        if (tcpLog.isLoggable(Log.BRIEF)) {
            tcpLog.log(Log.BRIEF, "(port " + port +
                ") exception: ", e);
        }
    } finally {
        try {
            conn.close();
        } catch (IOException ex) {
            // eat exception
        }
    }
}
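The loop in handleMessages can be modeled as: read one op byte, act on it, and repeat only while the connection is persistent. The sketch below records actions instead of servicing calls and ignores any payload that follows an op (the real DGCAck case, for example, reads a UID). The op values Call 0x50, Ping 0x52 and DGCAck 0x54 are assumed from TransportConstants as I recall them.

```java
import java.io.*;
import java.util.ArrayList;
import java.util.List;

class HandleMessagesSketch {
    // Assumed TransportConstants values; verify against your JDK.
    static final int CALL = 0x50;
    static final int PING = 0x52;
    static final int DGC_ACK = 0x54;

    // Returns the actions taken, in order.
    static List<String> handle(byte[] ops, boolean persistent) {
        List<String> actions = new ArrayList<>();
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(ops));
        try {
            do {
                int op = in.read(); // transport op
                if (op == -1) {
                    actions.add("closed");
                    break;
                }
                switch (op) {
                    case CALL:    actions.add("serviceCall"); break;
                    case PING:    actions.add("pingAck"); break;
                    case DGC_ACK: actions.add("dgcAck"); break;
                    default: throw new IOException("unknown transport op " + op);
                }
            } while (persistent);
        } catch (IOException e) {
            actions.add("error: " + e.getMessage()); // the real code logs and closes the connection
        }
        return actions;
    }
}
```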

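It reads a transport op byte from the stream and switches on it; for an ordinary remote call (TransportConstants.Call) it goes to serviceCall(call):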

public boolean serviceCall(final RemoteCall call) {
    try {
        /* read object id */
        final Remote impl;
        ObjID id;

        try {
            id = ObjID.read(call.getInputStream());
        } catch (java.io.IOException e) {
            throw new MarshalException("unable to read objID", e);
        }

        /* get the remote object */
        Transport transport = id.equals(dgcID) ? null : this;
        Target target =
            ObjectTable.getTarget(new ObjectEndpoint(id, transport));

        if (target == null || (impl = target.getImpl()) == null) {
            throw new NoSuchObjectException("no such object in table");
        }

        final Dispatcher disp = target.getDispatcher();
        target.incrementCallCount();
        try {
            /* call the dispatcher */
            transportLog.log(Log.VERBOSE, "call dispatcher");

            final AccessControlContext acc =
                target.getAccessControlContext();
            ClassLoader ccl = target.getContextClassLoader();

            ClassLoader savedCcl = Thread.currentThread().getContextClassLoader();

            try {
                setContextClassLoader(ccl);
                currentTransport.set(this);
                try {
                    java.security.AccessController.doPrivileged(
                        new java.security.PrivilegedExceptionAction<Void>() {
                        public Void run() throws IOException {
                            checkAcceptPermission(acc);
                            disp.dispatch(impl, call);
                            return null;
                        }
                    }, acc);
                } catch (java.security.PrivilegedActionException pae) {
                    throw (IOException) pae.getException();
                }
            } finally {
                setContextClassLoader(savedCcl);
                currentTransport.set(null);
            }

        } catch (IOException ex) {
            transportLog.log(Log.BRIEF,
                             "exception thrown by dispatcher: ", ex);
            return false;
        } finally {
            target.decrementCallCount();
        }

    } catch (RemoteException e) {

        // if calls are being logged, write out exception
        if (UnicastServerRef.callLog.isLoggable(Log.BRIEF)) {
            // include client host name if possible
            String clientHost = "";
            try {
                clientHost = "[" +
                    RemoteServer.getClientHost() + "] ";
            } catch (ServerNotActiveException ex) {
            }
            String message = clientHost + "exception: ";
            UnicastServerRef.callLog.log(Log.BRIEF, message, e);
        }

        
        try {
            ObjectOutput out = call.getResultStream(false);
            UnicastServerRef.clearStackTraces(e);
            out.writeObject(e);
            call.releaseOutputStream();

        } catch (IOException ie) {
            transportLog.log(Log.BRIEF,
                "exception thrown marshalling exception: ", ie);
            return false;
        }
    }

    return true;
}

Here, Target target = ObjectTable.getTarget(new ObjectEndpoint(id, transport)); looks the target up in the object table that was populated when the object was exported.
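The lookup key is the pair (ObjID, Transport), with one exception visible at the top of serviceCall: for the DGC's ObjID the transport is replaced by null. As far as I can tell, the DGC target is also registered without a transport, so it matches no matter which endpoint received the call. The model below uses hypothetical string IDs and transports instead of real ObjID and Transport objects.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

class ObjectTableSketch {
    static final String DGC_ID = "dgc"; // hypothetical stand-in for the DGC's ObjID

    // Stand-in for ObjectEndpoint: (object id, transport), transport may be null.
    static final class Key {
        final String id;
        final String transport;

        Key(String id, String transport) {
            this.id = id;
            this.transport = transport;
        }

        @Override public boolean equals(Object o) {
            if (!(o instanceof Key)) return false;
            Key k = (Key) o;
            return id.equals(k.id) && Objects.equals(transport, k.transport);
        }

        @Override public int hashCode() { return Objects.hash(id, transport); }
    }

    private final Map<Key, String> table = new HashMap<>();

    // Like putTarget: the DGC entry is stored without a transport.
    void put(String id, String transport, String impl) {
        table.put(new Key(id, id.equals(DGC_ID) ? null : transport), impl);
    }

    // Like the lookup in serviceCall: the transport is ignored for the DGC id.
    String lookup(String id, String transport) {
        String impl = table.get(new Key(id, id.equals(DGC_ID) ? null : transport));
        if (impl == null) {
            throw new IllegalStateException("no such object in table");
        }
        return impl;
    }
}
```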

Set a breakpoint at if (target == null || (impl = target.getImpl()) == null), then send a request from the client.

Keep stepping until disp.dispatch(impl, call); and step into it:

public void dispatch(Remote obj, RemoteCall call) throws IOException {
    
    int num;
    long op;

    try {
        // read remote call header
        ObjectInput in;
        try {
            in = call.getInputStream();
            num = in.readInt();
            if (num >= 0) {
                if (skel != null) {
                    oldDispatch(obj, call, num);
                    return;
                } else {
                    throw new UnmarshalException(
                        "skeleton class not found but required " +
                        "for client version");
                }
            }
            op = in.readLong();
        } catch (Exception readEx) {
            throw new UnmarshalException("error unmarshalling call header",
                                         readEx);
        }

       
        MarshalInputStream marshalStream = (MarshalInputStream) in;
        marshalStream.skipDefaultResolveClass();

        Method method = hashToMethod_Map.get(op);
        if (method == null) {
            throw new UnmarshalException("unrecognized method hash: " +
                "method not supported by remote object");
        }

        // if calls are being logged, write out object id and operation
        logCall(obj, method);

        // unmarshal parameters
        Class<?>[] types = method.getParameterTypes();
        Object[] params = new Object[types.length];

        try {
            unmarshalCustomCallData(in);
            for (int i = 0; i < types.length; i++) {
                params[i] = unmarshalValue(types[i], in);
            }
        } catch (java.io.IOException e) {
            throw new UnmarshalException(
                "error unmarshalling arguments", e);
        } catch (ClassNotFoundException e) {
            throw new UnmarshalException(
                "error unmarshalling arguments", e);
        } finally {
            call.releaseInputStream();
        }

        // make upcall on remote object
        Object result;
        try {
            result = method.invoke(obj, params);
        } catch (InvocationTargetException e) {
            throw e.getTargetException();
        }

        // marshal return value
        try {
            ObjectOutput out = call.getResultStream(true);
            Class<?> rtype = method.getReturnType();
            if (rtype != void.class) {
                marshalValue(rtype, result, out);
            }
        } catch (IOException ex) {
            throw new MarshalException("error marshalling return", ex);
           
        }
    } catch (Throwable e) {
        logCallException(e);

        ObjectOutput out = call.getResultStream(false);
        if (e instanceof Error) {
            e = new ServerError(
                "Error occurred in server thread", (Error) e);
        } else if (e instanceof RemoteException) {
            e = new ServerException(
                "RemoteException occurred in server thread",
                (Exception) e);
        }
        if (suppressStackTraces) {
            clearStackTraces(e);
        }
        out.writeObject(e);
    } finally {
        call.releaseInputStream(); // in case skeleton doesn't
        call.releaseOutputStream();
    }
}
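The first int of the call header decides which path dispatch takes: a non-negative value is an operation number for a skeleton (oldDispatch), while a negative value means a method hash follows. As far as I know, pre-generated stubs such as RegistryImpl_Stub send the operation index (0 to 4 for the Registry) followed by the interface hash, and dynamic-proxy stubs send -1 followed by the method hash. The sketch below uses a plain DataOutputStream for simplicity; the real call stream is an ObjectOutputStream and also carries the ObjID, which serviceCall has already consumed.

```java
import java.io.*;

class DispatchHeaderSketch {
    // Client side: int operation number, then a long hash.
    static byte[] header(int op, long hash) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            out.writeInt(op);
            out.writeLong(hash);
            out.flush();
            return bytes.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Server side: the same decision as the top of UnicastServerRef.dispatch.
    static String route(byte[] data, boolean hasSkeleton) {
        try {
            DataInputStream in = new DataInputStream(new ByteArrayInputStream(data));
            int num = in.readInt();
            if (num >= 0) {
                // Skeleton-style call: num is an operation index.
                return hasSkeleton ? "oldDispatch(op=" + num + ")"
                                   : "error: skeleton class not found but required for client version";
            }
            // Dynamic-proxy call: a method hash follows.
            long op = in.readLong();
            return "hash dispatch(" + op + ")";
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```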

The Registry has a skeleton, so skel is not null and the num >= 0 branch is taken. Step into oldDispatch(obj, call, num);

public void oldDispatch(Remote obj, RemoteCall call, int op)
    throws IOException
{
    long hash;              // hash for matching stub with skeleton

    try {
        // read remote call header
        ObjectInput in;
        try {
            in = call.getInputStream();
            try {
                Class<?> clazz = Class.forName("sun.rmi.transport.DGCImpl_Skel");
                if (clazz.isAssignableFrom(skel.getClass())) {
                    ((MarshalInputStream)in).useCodebaseOnly();
                }
            } catch (ClassNotFoundException ignore) { }
            hash = in.readLong();
        } catch (Exception readEx) {
            throw new UnmarshalException("error unmarshalling call header",
                                         readEx);
        }

        // if calls are being logged, write out object id and operation
        logCall(obj, skel.getOperations()[op]);
        unmarshalCustomCallData(in);
        // dispatch to skeleton for remote object
        skel.dispatch(obj, call, op, hash);

    } catch (Throwable e) {
        logCallException(e);

        ObjectOutput out = call.getResultStream(false);
        if (e instanceof Error) {
            e = new ServerError(
                "Error occurred in server thread", (Error) e);
        } else if (e instanceof RemoteException) {
            e = new ServerException(
                "RemoteException occurred in server thread",
                (Exception) e);
        }
        if (suppressStackTraces) {
            clearStackTraces(e);
        }
        out.writeObject(e);
    } finally {
        call.releaseInputStream(); // in case skeleton doesn't
        call.releaseOutputStream();
    }
}

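Then step into skel.dispatch(obj, call, op, hash);, which for the Registry is RegistryImpl_Skel.dispatch (decompiled, hence the varN names):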

public void dispatch(Remote var1, RemoteCall var2, int var3, long var4) throws Exception {
    if (var4 != 4905912898345647071L) {
        throw new SkeletonMismatchException("interface hash mismatch");
    } else {
        RegistryImpl var6 = (RegistryImpl)var1;
        String var7;
        Remote var8;
        ObjectInput var10;
        ObjectInput var11;
        switch (var3) {
            case 0:
                try {
                    var11 = var2.getInputStream();
                    var7 = (String)var11.readObject();
                    var8 = (Remote)var11.readObject();
                } catch (IOException var94) {
                    throw new UnmarshalException("error unmarshalling arguments", var94);
                } catch (ClassNotFoundException var95) {
                    throw new UnmarshalException("error unmarshalling arguments", var95);
                } finally {
                    var2.releaseInputStream();
                }

                var6.bind(var7, var8);

                try {
                    var2.getResultStream(true);
                    break;
                } catch (IOException var93) {
                    throw new MarshalException("error marshalling return", var93);
                }
            case 1:
                var2.releaseInputStream();
                String[] var97 = var6.list();

                try {
                    ObjectOutput var98 = var2.getResultStream(true);
                    var98.writeObject(var97);
                    break;
                } catch (IOException var92) {
                    throw new MarshalException("error marshalling return", var92);
                }
            case 2:
                try {
                    var10 = var2.getInputStream();
                    var7 = (String)var10.readObject();
                } catch (IOException var89) {
                    throw new UnmarshalException("error unmarshalling arguments", var89);
                } catch (ClassNotFoundException var90) {
                    throw new UnmarshalException("error unmarshalling arguments", var90);
                } finally {
                    var2.releaseInputStream();
                }

                var8 = var6.lookup(var7);

                try {
                    ObjectOutput var9 = var2.getResultStream(true);
                    var9.writeObject(var8);
                    break;
                } catch (IOException var88) {
                    throw new MarshalException("error marshalling return", var88);
                }
            case 3:
                try {
                    var11 = var2.getInputStream();
                    var7 = (String)var11.readObject();
                    var8 = (Remote)var11.readObject();
                } catch (IOException var85) {
                    throw new UnmarshalException("error unmarshalling arguments", var85);
                } catch (ClassNotFoundException var86) {
                    throw new UnmarshalException("error unmarshalling arguments", var86);
                } finally {
                    var2.releaseInputStream();
                }

                var6.rebind(var7, var8);

                try {
                    var2.getResultStream(true);
                    break;
                } catch (IOException var84) {
                    throw new MarshalException("error marshalling return", var84);
                }
            case 4:
                try {
                    var10 = var2.getInputStream();
                    var7 = (String)var10.readObject();
                } catch (IOException var81) {
                    throw new UnmarshalException("error unmarshalling arguments", var81);
                } catch (ClassNotFoundException var82) {
                    throw new UnmarshalException("error unmarshalling arguments", var82);
                } finally {
                    var2.releaseInputStream();
                }

                var6.unbind(var7);

                try {
                    var2.getResultStream(true);
                    break;
                } catch (IOException var80) {
                    throw new MarshalException("error marshalling return", var80);
                }
            default:
                throw new UnmarshalException("invalid method number");
        }

    }
}

This is where a client can attack the Registry.

Attack methods

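A quick word about this source first: it is long, but almost all of it is a switch over operation numbers.
A client can interact with the Registry through the following methods: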

 list

 bind

 rebind

 unbind

 lookup

These methods are handled in RegistryImpl_Skel#dispatch, which is the dispatch method we are in right now.

Any case that calls readObject on data sent by the client can be exploited. The operation numbers in dispatch map as follows:

0->bind
1->list
2->lookup
3->rebind
4->unbind

Any operation that deserializes its arguments is exploitable. Only list never calls readObject, so list cannot be used for this attack.
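To make "readObject happens before the registry operation" concrete, here is a hypothetical, network-free model of the switch in RegistryImpl_Skel#dispatch. It reads arguments with readObject for bind, lookup, rebind and unbind and only then touches the bindings map, while list reads nothing. Values are plain Serializable objects instead of Remote stubs, lookup returns null instead of throwing NotBoundException, and a counter records how many argument deserializations each call caused.

```java
import java.io.*;
import java.util.HashMap;
import java.util.Map;

class RegistrySkelSketch {
    private final Map<String, Object> bindings = new HashMap<>();
    private int readObjectCalls = 0; // counts argument deserializations

    private Object read(ObjectInput in) throws IOException, ClassNotFoundException {
        readObjectCalls++;
        return in.readObject(); // in the real skeleton, client-controlled bytes are deserialized here
    }

    // Op numbers follow the decompiled skeleton: 0 bind, 1 list, 2 lookup, 3 rebind, 4 unbind.
    Object dispatch(int op, byte[] argBytes) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(argBytes))) {
            switch (op) {
                case 0: {
                    String name = (String) read(in);
                    Object value = read(in); // deserialized before bind is even attempted
                    if (bindings.putIfAbsent(name, value) != null) {
                        throw new IllegalStateException("already bound: " + name);
                    }
                    return null;
                }
                case 1:
                    return bindings.keySet().toArray(new String[0]); // no readObject at all
                case 2:
                    return bindings.get((String) read(in));
                case 3: {
                    String name = (String) read(in);
                    bindings.put(name, read(in));
                    return null;
                }
                case 4:
                    bindings.remove((String) read(in));
                    return null;
                default:
                    throw new IllegalArgumentException("invalid method number");
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    int readObjectCalls() { return readObjectCalls; }

    // Serializes call arguments in order (a simplified stand-in for the client stub).
    static byte[] args(Object... values) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                for (Object v : values) {
                    out.writeObject(v);
                }
            }
            return bytes.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```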

8. How the server handles a client request

The call enters the same dispatch method:

public void dispatch(Remote obj, RemoteCall call) throws IOException {
    int num;
    long op;

    try {
        // read remote call header
        ObjectInput in;
        try {
            in = call.getInputStream();
            num = in.readInt();
            if (num >= 0) {
                if (skel != null) {
                    oldDispatch(obj, call, num);
                    return;
                } else {
                    throw new UnmarshalException(
                        "skeleton class not found but required " +
                        "for client version");
                }
            }
            op = in.readLong();
        } catch (Exception readEx) {
            throw new UnmarshalException("error unmarshalling call header",
                                         readEx);
        }

       
        MarshalInputStream marshalStream = (MarshalInputStream) in;
        marshalStream.skipDefaultResolveClass();

        Method method = hashToMethod_Map.get(op);
        if (method == null) {
            throw new UnmarshalException("unrecognized method hash: " +
                "method not supported by remote object");
        }

        // if calls are being logged, write out object id and operation
        logCall(obj, method);

        // unmarshal parameters
        Class<?>[] types = method.getParameterTypes();
        Object[] params = new Object[types.length];

        try {
            unmarshalCustomCallData(in);
            for (int i = 0; i < types.length; i++) {
                params[i] = unmarshalValue(types[i], in);
            }
        } catch (java.io.IOException e) {
            throw new UnmarshalException(
                "error unmarshalling arguments", e);
        } catch (ClassNotFoundException e) {
            throw new UnmarshalException(
                "error unmarshalling arguments", e);
        } finally {
            call.releaseInputStream();
        }

        // make upcall on remote object
        Object result;
        try {
            result = method.invoke(obj, params);
        } catch (InvocationTargetException e) {
            throw e.getTargetException();
        }

        // marshal return value
        try {
            ObjectOutput out = call.getResultStream(true);
            Class<?> rtype = method.getReturnType();
            if (rtype != void.class) {
                marshalValue(rtype, result, out);
            }
        } catch (IOException ex) {
            throw new MarshalException("error marshalling return", ex);
           
        }
    } catch (Throwable e) {
        logCallException(e);

        ObjectOutput out = call.getResultStream(false);
        if (e instanceof Error) {
            e = new ServerError(
                "Error occurred in server thread", (Error) e);
        } else if (e instanceof RemoteException) {
            e = new ServerException(
                "RemoteException occurred in server thread",
                (Exception) e);
        }
        if (suppressStackTraces) {
            clearStackTraces(e);
        }
        out.writeObject(e);
    } finally {
        call.releaseInputStream(); // in case skeleton doesn't
        call.releaseOutputStream();
    }
}
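The registry's skeleton switches on an opnum. The code above does something different: it finds the target method with hashToMethod_Map.get(op), where op is a 64-bit method hash. The following is a minimal sketch that assumes the hash definition from the RMI specification: SHA-1 over the writeUTF encoding of the method name plus its JVM descriptor, with the first 8 digest bytes read little-endian. The class name MethodHash is hypothetical, so verify the result against your JDK before relying on it.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.security.DigestOutputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

// Hypothetical sketch of a 64-bit RMI method hash computed from
// "name(descriptor)", modeled on the RMI specification's definition.
final class MethodHash {
    private MethodHash() {}

    static long compute(String nameAndDescriptor) {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-1");
            DataOutputStream out = new DataOutputStream(
                new DigestOutputStream(new ByteArrayOutputStream(), md));
            // writeUTF emits a 2-byte length prefix followed by modified UTF-8,
            // and every byte written also updates the digest.
            out.writeUTF(nameAndDescriptor);
            out.flush();
            byte[] digest = md.digest();
            long hash = 0;
            // The first 8 digest bytes, read little-endian, form the hash.
            for (int i = 0; i < 8; i++) {
                hash += ((long) (digest[i] & 0xFF)) << (i * 8);
            }
            return hash;
        } catch (IOException | NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

For the sayHello method of IRemoteObj, the input string would be sayHello(Ljava/lang/String;)Ljava/lang/String;. The server keeps a map from such hashes to Method objects, so an unknown hash ends in the "unrecognized method hash" exception above.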

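The difference is that it does not take the if (skel != null) branch. When the call comes through a dynamic proxy stub, the client sends -1 as the operation number, so the num >= 0 check fails, and skel is null for this object anyway. The method is looked up by its hash op in hashToMethod_Map instead, and execution continues down to the parameter loop: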

for (int i = 0; i < types.length; i++) {
    params[i] = unmarshalValue(types[i], in);
}

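Here, just as in the registry case, the arguments are deserialized: unmarshalValue reads every non-primitive parameter with readObject, so this is also vulnerable.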
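The danger in this loop comes from how each parameter is read. The sketch below is a simplified stand-in for unmarshalValue. It assumes the behavior seen in the JDK source: primitives are read directly, and every other type goes through readObject(). The class ParamUnmarshaller and its demo() round trip are hypothetical.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

// Hypothetical, simplified stand-in for the per-parameter unmarshalling step.
final class ParamUnmarshaller {
    private ParamUnmarshaller() {}

    static Object unmarshalValue(Class<?> type, ObjectInput in)
            throws IOException, ClassNotFoundException {
        if (type.isPrimitive()) {
            if (type == int.class) return in.readInt();
            if (type == boolean.class) return in.readBoolean();
            if (type == long.class) return in.readLong();
            if (type == byte.class) return in.readByte();
            if (type == char.class) return in.readChar();
            if (type == short.class) return in.readShort();
            if (type == float.class) return in.readFloat();
            if (type == double.class) return in.readDouble();
            throw new IOException("unsupported primitive: " + type);
        }
        // Any non-primitive parameter (String, Object, ...) is deserialized here.
        return in.readObject();
    }

    // Round trip: write an int and a String, then read them back by parameter type.
    static Object[] demo() {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeInt(42);
                out.writeObject("hello");
            }
            Class<?>[] types = { int.class, String.class };
            Object[] params = new Object[types.length];
            try (ObjectInputStream in = new ObjectInputStream(
                    new ByteArrayInputStream(bytes.toByteArray()))) {
                for (int i = 0; i < types.length; i++) {
                    params[i] = unmarshalValue(types[i], in);
                }
            }
            return params;
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

A remote method declared with only primitive parameters would never reach readObject in this loop. A parameter typed String or Object always does.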

9. DGC

DGC (Distributed Garbage Collection) handles memory reclamation for remote objects, and it uses a random port.

It is triggered from putTarget, which exportObject calls:

public void exportObject(Target target) throws RemoteException {
    target.setExportedTransport(this);
    ObjectTable.putTarget(target);
}

Stepping into putTarget:

static void putTarget(Target target) throws ExportException {
    ObjectEndpoint oe = target.getObjectEndpoint();
    WeakRef weakImpl = target.getWeakImpl();

    if (DGCImpl.dgcLog.isLoggable(Log.VERBOSE)) {
        DGCImpl.dgcLog.log(Log.VERBOSE, "add object " + oe);
    }

    synchronized (tableLock) {
        /**
         * Do nothing if impl has already been collected (see 6597112). Check while
         * holding tableLock to ensure that Reaper cannot process weakImpl in between
         * null check and put/increment effects.
         */
        if (target.getImpl() != null) {
            if (objTable.containsKey(oe)) {
                throw new ExportException(
                    "internal error: ObjID already in use");
            } else if (implTable.containsKey(weakImpl)) {
                throw new ExportException("object already exported");
            }

            objTable.put(oe, target);
            implTable.put(weakImpl, target);

            if (!target.isPermanent()) {
                incrementKeepAliveCount();
            }
        }
    }
}
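The bookkeeping in putTarget can be modeled with two maps. It rejects a reused object ID, rejects re-exporting the same implementation, and bumps the keep-alive count for non-permanent targets. This is a hypothetical simplification: MiniObjectTable uses plain long IDs and an IdentityHashMap instead of ObjectEndpoint keys and WeakRef wrappers.

```java
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;

// Hypothetical, simplified model of ObjectTable.putTarget's bookkeeping.
final class MiniObjectTable {
    private final Map<Long, Object> objTable = new HashMap<>();       // keyed by object ID
    private final Map<Object, Long> implTable = new IdentityHashMap<>(); // keyed by impl identity
    private int keepAliveCount = 0;

    synchronized void putTarget(long objId, Object impl, boolean permanent) {
        if (objTable.containsKey(objId)) {
            throw new IllegalStateException("internal error: ObjID already in use");
        } else if (implTable.containsKey(impl)) {
            throw new IllegalStateException("object already exported");
        }
        objTable.put(objId, impl);
        implTable.put(impl, objId);
        // Only non-permanent exports count toward keeping the runtime alive;
        // permanent ones (the DGC target is created with permanent = true) do not.
        if (!permanent) {
            keepAliveCount++;
        }
    }

    synchronized int keepAliveCount() {
        return keepAliveCount;
    }

    synchronized int size() {
        return objTable.size();
    }
}
```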

Next, look at the first if statement inside it:

 if (DGCImpl.dgcLog.isLoggable(Log.VERBOSE)) {
        DGCImpl.dgcLog.log(Log.VERBOSE, "add object " + oe);
    }

dgcLog is a static field of DGCImpl:

 static final Log dgcLog = Log.getLog("sun.rmi.dgc", "dgc",
    LogStream.parseLevel(AccessController.doPrivileged(
        new GetPropertyAction("sun.rmi.dgc.logLevel"))));

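Because its initializer is a method call, dgcLog is not a compile-time constant, so accessing it triggers initialization of DGCImpl, which runs its static initializer block: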

static {
    
    AccessController.doPrivileged(new PrivilegedAction<Void>() {
        public Void run() {
            ClassLoader savedCcl =
                Thread.currentThread().getContextClassLoader();
            try {
                Thread.currentThread().setContextClassLoader(
                    ClassLoader.getSystemClassLoader());

                try {
                    dgc = new DGCImpl();
                    ObjID dgcID = new ObjID(ObjID.DGC_ID);
                    LiveRef ref = new LiveRef(dgcID, 0);
                    UnicastServerRef disp = new UnicastServerRef(ref);
                    Remote stub =
                        Util.createProxy(DGCImpl.class,
                                         new UnicastRef(ref), true);
                    disp.setSkeleton(dgc);

                    Permissions perms = new Permissions();
                    perms.add(new SocketPermission("*", "accept,resolve"));
                    ProtectionDomain[] pd = { new ProtectionDomain(null, perms) };
                    AccessControlContext acceptAcc = new AccessControlContext(pd);

                    Target target = AccessController.doPrivileged(
                        new PrivilegedAction<Target>() {
                            public Target run() {
                                return new Target(dgc, disp, stub, dgcID, true);
                            }
                        }, acceptAcc);

                    ObjectTable.putTarget(target);
                } catch (RemoteException e) {
                    throw new Error(
                        "exception initializing server-side DGC", e);
                }
            } finally {
                Thread.currentThread().setContextClassLoader(savedCcl);
            }
            return null;
        }
    });
}
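This explains why reading DGCImpl.dgcLog runs the static block above. A static final field whose initializer is a method call is not a compile-time constant, so the first access to it initializes the class. Initialization then runs static field initializers and static blocks in textual order. The demo classes InitLog, DgcLike and ConstantHolder are hypothetical and only illustrate this Java rule. ConstantHolder shows the contrast case: a compile-time constant is inlined by the compiler and does not trigger initialization.

```java
import java.util.ArrayList;
import java.util.List;

// Records the order in which class initializers run (hypothetical demo classes).
final class InitLog {
    static final List<String> EVENTS = new ArrayList<>();

    private InitLog() {}
}

// Stands in for DGCImpl: the field initializer is a method call, so the field
// is not a compile-time constant and reading it initializes the class.
final class DgcLike {
    static final String dgcLog = makeLog();

    static {
        InitLog.EVENTS.add("DgcLike static block");
    }

    private static String makeLog() {
        InitLog.EVENTS.add("DgcLike field initializer");
        return "sun.rmi.dgc";
    }
}

// A compile-time constant is inlined by javac, so reading NAME does not
// initialize ConstantHolder and its static block does not run.
final class ConstantHolder {
    static final String NAME = "constant";

    static {
        InitLog.EVENTS.add("ConstantHolder static block");
    }
}
```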

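This leads to createProxy, the same proxy-creation logic as before. Note that the third argument passed here (forceStubUse) is true, so createProxy goes straight to createStub instead of building a dynamic proxy: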

public static Remote createProxy(Class<?> implClass,
                                 RemoteRef clientRef,
                                 boolean forceStubUse)
    throws StubNotFoundException
{
    Class<?> remoteClass;

    try {
        remoteClass = getRemoteClass(implClass);
    } catch (ClassNotFoundException ex ) {
        throw new StubNotFoundException(
            "object does not implement a remote interface: " +
            implClass.getName());
    }

    if (forceStubUse ||
        !(ignoreStubClasses || !stubClassExists(remoteClass)))
    {
        return createStub(remoteClass, clientRef);
    }

    final ClassLoader loader = implClass.getClassLoader();
    final Class<?>[] interfaces = getRemoteInterfaces(implClass);
    final InvocationHandler handler =
        new RemoteObjectInvocationHandler(clientRef);

    /* REMIND: private remote interfaces? */

    try {
        return AccessController.doPrivileged(new PrivilegedAction<Remote>() {
            public Remote run() {
                return (Remote) Proxy.newProxyInstance(loader,
                                                       interfaces,
                                                       handler);
            }});
    } catch (IllegalArgumentException e) {
        throw new StubNotFoundException("unable to create proxy", e);
    }
}
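When forceStubUse is false and no stub class is found, createProxy falls back to a dynamic proxy. Its handler, RemoteObjectInvocationHandler, turns each method call into a remote call. The sketch below shows only the java.lang.reflect.Proxy mechanics. Greeter and RecordingHandler are hypothetical, and the handler computes a local result instead of talking to the network.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

// Hypothetical remote-style interface used only for this demo.
interface Greeter {
    String sayHello(String keywords);
}

// Stand-in for RemoteObjectInvocationHandler: instead of sending the call
// over the network, it records the method name and computes a local result.
final class RecordingHandler implements InvocationHandler {
    final List<String> calls = new ArrayList<>();

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) {
        calls.add(method.getName());
        if (method.getName().equals("sayHello")) {
            return ((String) args[0]).toUpperCase();
        }
        throw new UnsupportedOperationException(method.getName());
    }
}

final class ProxyDemo {
    private ProxyDemo() {}

    // Same shape as the fallback branch of createProxy: loader, interfaces, handler.
    static Greeter createProxy(RecordingHandler handler) {
        return (Greeter) Proxy.newProxyInstance(
            Greeter.class.getClassLoader(),
            new Class<?>[] { Greeter.class },
            handler);
    }
}
```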

Stepping into createStub:

  private static RemoteStub createStub(Class<?> remoteClass, RemoteRef ref)
    throws StubNotFoundException
{
    String stubname = remoteClass.getName() + "_Stub";

   
    try {
        Class<?> stubcl =
            Class.forName(stubname, false, remoteClass.getClassLoader());
        Constructor<?> cons = stubcl.getConstructor(stubConsParamTypes);
        return (RemoteStub) cons.newInstance(new Object[] { ref });

    } catch (ClassNotFoundException e) {
        throw new StubNotFoundException(
            "Stub class not found: " + stubname, e);
    } catch (NoSuchMethodException e) {
        throw new StubNotFoundException(
            "Stub class missing constructor: " + stubname, e);
    } catch (InstantiationException e) {
        throw new StubNotFoundException(
            "Can't create instance of stub class: " + stubname, e);
    } catch (IllegalAccessException e) {
        throw new StubNotFoundException(
            "Stub class constructor not public: " + stubname, e);
    } catch (InvocationTargetException e) {
        throw new StubNotFoundException(
            "Exception creating instance of stub class: " + stubname, e);
    } catch (ClassCastException e) {
        throw new StubNotFoundException(
            "Stub class not instance of RemoteStub: " + stubname, e);
    }
}
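The stub lookup in createStub is plain reflection driven by a naming convention. The following is a minimal sketch that assumes two hypothetical classes, DemoService and DemoService_Stub. The real code passes a RemoteRef to the stub constructor. This version accepts an Object so that it stays self-contained.

```java
import java.lang.reflect.Constructor;

// Hypothetical "remote" class and its matching *_Stub class.
class DemoService {}

class DemoService_Stub {
    final Object ref;

    public DemoService_Stub(Object ref) {
        this.ref = ref;
    }
}

final class StubLoader {
    private StubLoader() {}

    // Mirrors createStub's convention: <remote class name> + "_Stub", loaded
    // without initialization through the remote class's own class loader.
    static Object createStub(Class<?> remoteClass, Object ref) {
        String stubName = remoteClass.getName() + "_Stub";
        try {
            Class<?> stubClass =
                Class.forName(stubName, false, remoteClass.getClassLoader());
            Constructor<?> cons = stubClass.getConstructor(Object.class);
            return cons.newInstance(ref);
        } catch (ReflectiveOperationException e) {
            // Covers class-not-found, missing constructor, access and instantiation errors.
            throw new IllegalStateException("Stub class not usable: " + stubName, e);
        }
    }
}
```

Because the lookup is name-based, it succeeds only if a class with exactly that name exists on the loader. For DGC, that class is DGCImpl_Stub.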

As when the registry creates its remote object, this tries to load a class named after the remote class plus "_Stub", here DGCImpl_Stub.

Now let's focus on the vulnerable parts of the DGC stub.

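DGCImpl_Stub has two methods, clean and dirty. A client calls dirty to obtain or renew a lease on the remote objects it references, and calls clean to tell the server it no longer references them.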

 public void clean(ObjID[] var1, long var2, VMID var4, boolean var5) throws RemoteException {
    try {
        RemoteCall var6 = super.ref.newCall(this, operations, 0, -669196253586618813L);

        try {
            ObjectOutput var7 = var6.getOutputStream();
            var7.writeObject(var1);
            var7.writeLong(var2);
            var7.writeObject(var4);
            var7.writeBoolean(var5);
        } catch (IOException var8) {
            throw new MarshalException("error marshalling arguments", var8);
        }

        super.ref.invoke(var6);
        super.ref.done(var6);
    } catch (RuntimeException var9) {
        throw var9;
    } catch (RemoteException var10) {
        throw var10;
    } catch (Exception var11) {
        throw new UnexpectedException("undeclared checked exception", var11);
    }
}

public Lease dirty(ObjID[] var1, long var2, Lease var4) throws RemoteException {
    try {
        RemoteCall var5 = super.ref.newCall(this, operations, 1, -669196253586618813L);

        try {
            ObjectOutput var6 = var5.getOutputStream();
            var6.writeObject(var1);
            var6.writeLong(var2);
            var6.writeObject(var4);
        } catch (IOException var20) {
            throw new MarshalException("error marshalling arguments", var20);
        }

        super.ref.invoke(var5);

        Lease var24;
        try {
            ObjectInput var9 = var5.getInputStream();
            var24 = (Lease)var9.readObject();
        } catch (IOException var17) {
            throw new UnmarshalException("error unmarshalling return", var17);
        } catch (ClassNotFoundException var18) {
            throw new UnmarshalException("error unmarshalling return", var18);
        } finally {
            super.ref.done(var5);
        }

        return var24;
    } catch (RuntimeException var21) {
        throw var21;
    } catch (RemoteException var22) {
        throw var22;
    } catch (Exception var23) {
        throw new UnexpectedException("undeclared checked exception", var23);
    }
}

}

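Both are exploitable. dirty deserializes the returned Lease with readObject, and both methods go through super.ref.invoke(), which deserializes any exception object the server sends back.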
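The key point about (Lease)var9.readObject() is that the cast happens only after readObject() has fully deserialized whatever the server sent, and any readObject hooks have already run by then. The harmless demo below illustrates this. NoisyPayload, ExpectedLease and CastAfterReadDemo are hypothetical classes, with ExpectedLease standing in for Lease.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

// Harmless demo class: its readObject hook runs during deserialization.
class NoisyPayload implements Serializable {
    private static final long serialVersionUID = 1L;
    static volatile boolean readObjectRan = false;

    private void readObject(ObjectInputStream in)
            throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        readObjectRan = true; // side effect happens before any cast by the caller
    }
}

// Stand-in for the expected return type (Lease in DGCImpl_Stub#dirty).
class ExpectedLease implements Serializable {
    private static final long serialVersionUID = 1L;
}

final class CastAfterReadDemo {
    private CastAfterReadDemo() {}

    static byte[] serialize(Object obj) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Reads a NoisyPayload back the way the stub reads its return value:
    // readObject() first, cast second. Returns true if the cast failed.
    static boolean castFailedAfterRead() {
        byte[] data = serialize(new NoisyPayload());
        try (ObjectInputStream in = new ObjectInputStream(
                new ByteArrayInputStream(data))) {
            ExpectedLease lease = (ExpectedLease) in.readObject();
            System.out.println("cast succeeded: " + lease);
            return false;
        } catch (ClassCastException e) {
            return true;
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

The ClassCastException arrives too late to help, because the payload's readObject has already executed.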

Now let's look at DGCImpl_Skel:

    public void dispatch(Remote var1, RemoteCall var2, int var3, long var4) throws Exception {
    if (var4 != -669196253586618813L) {
        throw new SkeletonMismatchException("interface hash mismatch");
    } else {
        DGCImpl var6 = (DGCImpl)var1;
        ObjID[] var7;
        long var8;
        switch (var3) {
            case 0:
                VMID var39;
                boolean var40;
                try {
                    ObjectInput var14 = var2.getInputStream();
                    var7 = (ObjID[])var14.readObject();
                    var8 = var14.readLong();
                    var39 = (VMID)var14.readObject();
                    var40 = var14.readBoolean();
                } catch (IOException var36) {
                    throw new UnmarshalException("error unmarshalling arguments", var36);
                } catch (ClassNotFoundException var37) {
                    throw new UnmarshalException("error unmarshalling arguments", var37);
                } finally {
                    var2.releaseInputStream();
                }

                var6.clean(var7, var8, var39, var40);

                try {
                    var2.getResultStream(true);
                    break;
                } catch (IOException var35) {
                    throw new MarshalException("error marshalling return", var35);
                }
            case 1:
                Lease var10;
                try {
                    ObjectInput var13 = var2.getInputStream();
                    var7 = (ObjID[])var13.readObject();
                    var8 = var13.readLong();
                    var10 = (Lease)var13.readObject();
                } catch (IOException var32) {
                    throw new UnmarshalException("error unmarshalling arguments", var32);
                } catch (ClassNotFoundException var33) {
                    throw new UnmarshalException("error unmarshalling arguments", var33);
                } finally {
                    var2.releaseInputStream();
                }

                Lease var11 = var6.dirty(var7, var8, var10);

                try {
                    ObjectOutput var12 = var2.getResultStream(true);
                    var12.writeObject(var11);
                    break;
                } catch (IOException var31) {
                    throw new MarshalException("error marshalling return", var31);
                }
            default:
                throw new UnmarshalException("invalid method number");
        }

    }
}

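It is vulnerable as well: case 0 (clean) and case 1 (dirty) both call readObject on arguments supplied by the client.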
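The structure of DGCImpl_Skel#dispatch can be sketched as follows: check the interface hash, switch on the operation number, then readObject each argument. This is a hypothetical model. The class MiniSkel, the String[] and String stand-ins for ObjID[] and VMID, and the single-stream framing are all simplifications and do not reproduce the real JRMP wire format. Only the hash constant comes from the listing above.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;

// Hypothetical mini skeleton mirroring the shape of DGCImpl_Skel#dispatch.
final class MiniSkel {
    // Interface hash checked by DGCImpl_Skel in the listing above.
    static final long INTERFACE_HASH = -669196253586618813L;

    private MiniSkel() {}

    // "Client" side: opnum and interface hash, then clean-style arguments.
    static byte[] writeCleanCall(long hash, String[] ids, long seq, boolean strong) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeInt(0);          // opnum 0 -> clean
            out.writeLong(hash);
            out.writeObject(ids);     // stands in for ObjID[]
            out.writeLong(seq);
            out.writeObject("vmid");  // stands in for VMID
            out.writeBoolean(strong);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // "Server" side: verify the hash, switch on the opnum, readObject the arguments.
    static String dispatch(byte[] call) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(call))) {
            int opnum = in.readInt();
            long hash = in.readLong();
            if (hash != INTERFACE_HASH) {
                throw new IllegalArgumentException("interface hash mismatch");
            }
            switch (opnum) {
                case 0: {
                    // Each readObject here deserializes client-controlled bytes.
                    String[] ids = (String[]) in.readObject();
                    long seq = in.readLong();
                    in.readObject();              // VMID stand-in, still deserialized
                    boolean strong = in.readBoolean();
                    return "clean ids=" + ids.length + " seq=" + seq + " strong=" + strong;
                }
                case 1: {
                    String[] ids = (String[]) in.readObject();
                    long seq = in.readLong();
                    in.readObject();              // Lease stand-in, still deserialized
                    return "dirty ids=" + ids.length + " seq=" + seq;
                }
                default:
                    throw new IllegalArgumentException("invalid method number");
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```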
