RemoteFileDescriptorBase.java

/*
 * junixsocket
 *
 * Copyright 2009-2024 Christian Kohlschütter
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.newsclub.net.unix.rmi;

import java.io.Closeable;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.Externalizable;
import java.io.FileDescriptor;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.OutputStream;
import java.net.SocketException;
import java.util.Objects;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import org.newsclub.net.unix.AFServerSocket;
import org.newsclub.net.unix.AFSocket;
import org.newsclub.net.unix.AFSocketAddress;
import org.newsclub.net.unix.AFUNIXSocket;
import org.newsclub.net.unix.FileDescriptorAccess;
import org.newsclub.net.unix.server.AFSocketServer;

import com.kohlschutter.annotations.compiletime.SuppressFBWarnings;

/**
 * A wrapper that allows a {@link FileDescriptor} to be sent via RMI over AF_UNIX sockets.
 *
 * @author Christian Kohlschütter
 * @param <T> The resource type.
 * @see RemoteFileInput
 * @see RemoteFileOutput
 */
public abstract class RemoteFileDescriptorBase<T> implements Externalizable, Closeable,
    FileDescriptorAccess {
  private static final String PROP_SERVER_TIMEOUT =
      "org.newsclub.net.unix.rmi.rfd-server-timeout-millis";
  private static final String PROP_CONNECT_TIMEOUT =
      "org.newsclub.net.unix.rmi.rfd-connect-timeout-millis";

  private static final int SERVER_TIMEOUT = //
      parseTimeoutMillis(System.getProperty(PROP_SERVER_TIMEOUT, "10000"), false);
  private static final int CONNECT_TIMEOUT = //
      parseTimeoutMillis(System.getProperty(PROP_CONNECT_TIMEOUT, "1000"), true);

  static final int MAGIC_VALUE_MASK = 0x00FD0000;
  static final int BIT_READABLE = 1 << 0;
  static final int BIT_WRITABLE = 1 << 1;

  private static final long serialVersionUID = 1L;

  /**
   * The input side of the connection to the sender's server, set by
   * {@link #readExternal(ObjectInput)}.
   */
  private final AtomicReference<DataInputStream> remoteConnection = new AtomicReference<>();

  /**
   * The socket connected to the sender's server, set by {@link #readExternal(ObjectInput)}.
   */
  private final AtomicReference<AFUNIXSocket> remoteServer = new AtomicReference<>();

  /**
   * An optional, closeable resource that is related to this instance. If the reference is non-null,
   * this will be closed upon {@link #close()}.
   *
   * For unidirectional implementations, this could be the corresponding input/output stream. For
   * bidirectional implementations (e.g., a Socket, Pipe, etc.), this should close both directions.
   */
  protected final transient AtomicReference<T> resource = new AtomicReference<>();

  private int magicValue;
  private transient FileDescriptor fd;
  private AFUNIXRMISocketFactory socketFactory;

  /**
   * Creates an uninitialized instance; used for externalization.
   *
   * @see #readExternal(ObjectInput)
   */
  public RemoteFileDescriptorBase() {
  }

  RemoteFileDescriptorBase(AFUNIXRMISocketFactory socketFactory, T stream, FileDescriptor fd,
      int magicValue) {
    this.resource.set(stream);
    this.socketFactory = socketFactory;
    this.fd = fd;
    this.magicValue = magicValue;
  }

  /**
   * Serializes this instance.
   *
   * The file descriptor itself cannot be written to an object stream. Instead, a server socket is
   * opened via the socket factory, and the descriptor is handed out as ancillary data to the peer
   * that connects to it. The stream receives the socket factory, the magic value, a random token
   * (so the receiver can verify it reached the right server) and the server's local port.
   *
   * If the server cannot be set up, the {@link IOException} is written to the stream (so that
   * {@link #readExternal(ObjectInput)} can report it) and then rethrown.
   *
   * @param objOut The output to write to.
   * @throws IOException if the file descriptor is missing/invalid, or on I/O error.
   */
  @Override
  @SuppressWarnings("PMD.ExceptionAsFlowControl")
  public final void writeExternal(ObjectOutput objOut) throws IOException {
    if (fd == null || !fd.valid()) {
      throw new IOException("No or invalid file descriptor");
    }
    final int randomValue = ThreadLocalRandom.current().nextInt();

    final int localPort;
    try {
      localPort = startFileDescriptorServer(randomValue);
    } catch (IOException e) {
      // Let the receiving side know why there is no server to connect to.
      objOut.writeObject(e);
      throw e;
    }

    objOut.writeObject(socketFactory);
    objOut.writeInt(magicValue);
    objOut.writeInt(randomValue);
    objOut.writeInt(localPort);
    objOut.flush();
  }

  /**
   * Opens a server socket via the socket factory and starts a server on it.
   *
   * The server sends {@code randomValue} along with {@code fd} as ancillary data to a connecting
   * peer. It then waits for the peer to acknowledge, and stops afterwards. It is also scheduled to
   * stop after {@code SERVER_TIMEOUT} milliseconds.
   *
   * @param randomValue The token sent ahead of the file descriptor.
   * @return The local port of the server socket.
   * @throws IOException on error. In that case the server socket is closed before returning.
   */
  private int startFileDescriptorServer(final int randomValue) throws IOException {
    final AFServerSocket<?> serverSocket = (AFServerSocket<?>) socketFactory.createServerSocket(0);
    boolean started = false;
    try {
      int localPort = serverSocket.getLocalPort();

      AFSocketServer<?> server = new AFSocketServer<AFSocketAddress>(serverSocket) {
        @Override
        protected void doServeSocket(AFSocket<?> socket) throws IOException {
          AFUNIXSocket unixSocket = (AFUNIXSocket) socket;
          try (DataOutputStream out = new DataOutputStream(socket.getOutputStream());
              InputStream in = socket.getInputStream();) {
            // Ancillary data can only be transmitted along with regular data, so send the token.
            unixSocket.setOutboundFileDescriptors(fd);
            out.writeInt(randomValue);

            try {
              socket.setSoTimeout(CONNECT_TIMEOUT);
            } catch (IOException e) {
              // ignore; best effort
            }

            // This call blocks until the remote is done with the file descriptor (it sends 1 or
            // closes the connection), or we time out.
            int response = in.read();
            if (response != 1 && response != -1) { // -1: EOF, remote terminated
              throw new IOException("Unexpected response: " + response);
            }
          } finally {
            stop();
          }
        }

        @Override
        protected void onServerStopped(AFServerSocket<?> socket) {
          try {
            serverSocket.close();
          } catch (IOException e) {
            // ignore
          }
        }

      };
      @SuppressWarnings("unused")
      ScheduledFuture<IOException> unused = server.startThenStopAfter(SERVER_TIMEOUT,
          TimeUnit.MILLISECONDS);
      started = true;
      return localPort;
    } finally {
      if (!started) {
        // No server will ever serve (and eventually close) this socket; don't leak it.
        try {
          serverSocket.close();
        } catch (IOException e) {
          // ignore; the original failure is already propagating
        }
      }
    }
  }

  /**
   * Deserializes this instance by connecting to the sender's server and receiving the file
   * descriptor as ancillary data.
   *
   * If this method fails after the connection has been established, the connection is closed.
   * No caller could otherwise call {@link #close()} on the half-initialized instance.
   *
   * @param objIn The input to read from.
   * @throws IOException if the sender reported an error, the stream contents are unexpected, the
   *           token does not match, not exactly one file descriptor was received, or on I/O error.
   * @throws ClassNotFoundException if the socket factory's class cannot be found.
   * @throws IllegalStateException if this instance is already connected to a remote server.
   */
  @SuppressWarnings("resource")
  @Override
  public final void readExternal(ObjectInput objIn) throws IOException, ClassNotFoundException {
    DataInputStream in1 = remoteConnection.getAndSet(null);
    if (in1 != null) {
      in1.close();
    }

    Object obj = objIn.readObject();
    if (obj instanceof IOException) {
      // The sender could not set up its server; see writeExternal
      IOException e = new IOException("Could not read RemoteFileDescriptor");
      e.addSuppressed((IOException) obj);
      throw e;
    }
    this.socketFactory = (AFUNIXRMISocketFactory) obj;

    // Since ancillary messages can only be read in combination with real data, we read and verify a
    // magic value
    this.magicValue = objIn.readInt();
    if ((magicValue & MAGIC_VALUE_MASK) != MAGIC_VALUE_MASK) {
      throw new IOException("Unexpected magic value: " + Integer.toHexString(magicValue));
    }
    final int randomValue = objIn.readInt();
    int port = objIn.readInt();

    AFUNIXSocket socket = (AFUNIXSocket) socketFactory.createSocket("", port);
    if (!remoteServer.compareAndSet(null, socket)) {
      // Keep the already-registered socket (close() will release it); discard the new one.
      IllegalStateException ise = new IllegalStateException("remoteServer was not null");
      closeSuppressing(socket, ise);
      throw ise;
    }

    try {
      this.fd = receiveFileDescriptor(socket, randomValue);
    } catch (IOException | RuntimeException e) {
      remoteConnection.set(null);
      remoteServer.compareAndSet(socket, null);
      closeSuppressing(socket, e);
      throw e;
    }
  }

  /**
   * Reads the token from the given connected socket, verifies it, and returns the single file
   * descriptor received alongside it.
   *
   * @param socket The socket connected to the sender's server.
   * @param expectedRandom The token the sender wrote to the object stream.
   * @return The received file descriptor.
   * @throws IOException on token mismatch, if not exactly one descriptor was received, or on I/O
   *           error.
   */
  private FileDescriptor receiveFileDescriptor(AFUNIXSocket socket, int expectedRandom)
      throws IOException {
    try {
      socket.setSoTimeout(CONNECT_TIMEOUT);
    } catch (IOException e) {
      // ignore; best effort
    }

    DataInputStream in = new DataInputStream(socket.getInputStream());
    this.remoteConnection.set(in);
    socket.ensureAncillaryReceiveBufferSize(128);

    int random = in.readInt();
    if (random != expectedRandom) {
      throw new IOException("Invalid socket connection");
    }

    FileDescriptor[] descriptors = socket.getReceivedFileDescriptors();
    if (descriptors == null || descriptors.length != 1) {
      throw new IOException("Did not receive exactly 1 file descriptor but " + (descriptors == null
          ? 0 : descriptors.length));
    }
    return descriptors[0];
  }

  /**
   * Closes the given resource on an error path. A failure to close is recorded as a suppressed
   * exception of {@code primary}, so the original error is not masked.
   *
   * @param closeable The resource to close.
   * @param primary The exception that is being propagated.
   */
  private static void closeSuppressing(Closeable closeable, Throwable primary) {
    try {
      closeable.close();
    } catch (IOException e) {
      primary.addSuppressed(e);
    }
  }

  /**
   * Returns the file descriptor.
   *
   * This is either the original one that was specified in the constructor or a copy that was sent
   * via RMI over an AF_UNIX connection as part of an ancillary message.
   *
   * @return The file descriptor.
   */
  @Override
  @SuppressFBWarnings("EI_EXPOSE_REP")
  public final FileDescriptor getFileDescriptor() {
    return fd;
  }

  /**
   * Returns the "magic value" for this type of file descriptor.
   *
   * The magic value consists of an indicator ("this is a file descriptor") as well as its
   * capabilities (read/write). It is used to prevent, for example, converting an output stream to
   * an input stream.
   *
   * @return The magic value.
   */
  protected final int getMagicValue() {
    return magicValue;
  }

  /**
   * Closes this instance.
   *
   * The connection to the sender's server, if any, is released. Before closing it, a {@code 1} is
   * sent to tell the sender it may stop waiting. Finally, the related {@link #resource} is closed if
   * it is {@link Closeable}.
   *
   * @throws IOException on error.
   */
  @SuppressWarnings("resource")
  @Override
  public void close() throws IOException {
    DataInputStream in1 = remoteConnection.getAndSet(null);
    if (in1 != null) {
      try {
        in1.close();
      } catch (SocketException e) {
        // ignore
      }
    }

    AFUNIXSocket remoteSocket = remoteServer.getAndSet(null);
    if (remoteSocket != null) {
      try (OutputStream out = remoteSocket.getOutputStream()) {
        out.write(1);
      } catch (SocketException e) {
        // ignore; the sender treats a closed connection as "done" as well
      } finally {
        remoteSocket.close();
      }
    }

    @SuppressWarnings("null")
    T c = this.resource.getAndSet(null);
    if (c instanceof Closeable) {
      ((Closeable) c).close();
    }
  }

  /**
   * Parses a timeout value given in milliseconds.
   *
   * @param s The string to parse; must not be {@code null}.
   * @param zeroPermitted Whether {@code 0} is an acceptable value.
   * @return The parsed, non-negative timeout.
   * @throws NullPointerException if {@code s} is {@code null}.
   * @throws IllegalArgumentException if {@code s} is not an integer, is negative, or is zero while
   *           {@code zeroPermitted} is {@code false}.
   */
  private static int parseTimeoutMillis(String s, boolean zeroPermitted) {
    Objects.requireNonNull(s);
    int duration;
    try {
      duration = Integer.parseInt(s);
    } catch (NumberFormatException e) {
      throw new IllegalArgumentException("Illegal timeout value: " + s, e);
    }
    if (duration < 0 || (duration == 0 && !zeroPermitted)) {
      throw new IllegalArgumentException("Illegal timeout value: " + s);
    }
    return duration;
  }
}