Package org.platformlayer.ssh.mina

Source Code of org.platformlayer.ssh.mina.MinaSshConnection

package org.platformlayer.ssh.mina;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.security.KeyPair;

import org.apache.sshd.ClientChannel;
import org.apache.sshd.agent.SshAgent;
import org.apache.sshd.agent.local.LocalAgentFactory;
import org.apache.sshd.client.channel.ChannelSession;
import org.apache.sshd.client.channel.ForwardLocalPort;
import org.apache.sshd.client.channel.SshTunnelSocket;
import org.apache.sshd.common.RuntimeSshException;
import org.platformlayer.ExceptionUtils;
import org.platformlayer.ops.process.ProcessExecution;
import org.platformlayer.ops.ssh.IServerKeyVerifier;
import org.platformlayer.ops.ssh.SshConnection;
import org.platformlayer.ops.ssh.SshException;
import org.platformlayer.ops.ssh.SshPortForward;
import org.platformlayer.ssh.SshConnectionInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fathomdb.TimeSpan;

public class MinaSshConnection extends SshConnection {
  private static final Logger log = LoggerFactory.getLogger(MinaSshConnection.class);

  public static final TimeSpan DEFAULT_SSH_EXECUTE_TIMEOUT = new TimeSpan("15s");

  public static final TimeSpan DEFAULT_SSH_CONNECT_TIMEOUT = new TimeSpan("15s");
  public static final TimeSpan DEFAULT_SSH_KEY_EXECUTE_TIMEOUT = new TimeSpan("90s");

  final MinaSshContext minaSshContext;

  MinaSshConnectionWrapper sshConnection;

  KeyPair keyPair;

  MinaSshContext getSshContext() {
    return minaSshContext;
  }

  @Override
  public synchronized void close() {
    if (sshConnection != null) {
      sshConnection.close();
    }

    sshConnection = null;
  }

  public synchronized void closeAndRemoveFromPool() {
    close();
  }

  protected MinaSshConnectionWrapper ensureConnected() throws IOException, SshException {
    if (sshConnection == null) {
      getSshConnection();
    }

    activateConnection(sshConnection);

    return sshConnection;
  }

  protected synchronized MinaSshConnectionWrapper getSshConnection() throws IOException, SshException {
    if (getHost() == null) {
      throw new NullPointerException("getHost() returns null");
    }
    if (getKeyPair() == null) {
      throw new NullPointerException("privateKey is required");
    }
    SshConnectionInfo connectionInfo = new SshConnectionInfo(getHost(), getPort(), getUser(), getKeyPair());

    if (sshConnection != null) {
      throw new IllegalStateException("Attempt to get SSH connection when connection not previously closed");
    }

    sshConnection = MinaSshConnectionWrapper.build(getSshContext(), connectionInfo);

    activateConnection(sshConnection);

    return sshConnection;
  }

  private void activateConnection(MinaSshConnectionWrapper connection) throws IOException, SshException {
    activateConnection(connection, DEFAULT_SSH_CONNECT_TIMEOUT, DEFAULT_SSH_KEY_EXECUTE_TIMEOUT);
  }

  private void activateConnection(MinaSshConnectionWrapper sshConnection, TimeSpan connectTimeout,
      TimeSpan keyExchangeTimeout) throws IOException, SshException {
    boolean okay = false;
    try {
      if (!sshConnection.isConnected()) {
        log.info("Making new SSH connection to " + getHost());

        sshConnection.setServerKeyVerifier(this.getServerKeyVerifier());

        if (connectTimeout == null) {
          connectTimeout = TimeSpan.ZERO;
        }
        if (keyExchangeTimeout == null) {
          keyExchangeTimeout = TimeSpan.ZERO;
        }

        sshConnection.connect(connectTimeout);
        if (!sshConnection.isConnected()) {
          throw new IllegalStateException("Connection completed, but could not get connection details");
        }
      } else {
        IServerKeyVerifier myServerKeyVerifier = getServerKeyVerifier();
        myServerKeyVerifier.verifyPooled(sshConnection.getServerKeyVerifier());
      }

      if (!sshConnection.isAuthenticationComplete()) {
        // Authenticate
        boolean isAuthenticated = sshConnection.authenticateWithPublicKey(getUser(), getKeyPair(),
            keyExchangeTimeout);

        if (isAuthenticated == false) {
          // This happens when we upload the wrong public_key...
          // double check if you hit this!
          // This also happens if we're running a command that isn't
          // valid
          throw new SshException("Authentication failed.  Tried to connect to " + getUser() + "@"
              + sshConnection.getConnectionInfo().getHost());
        } else {
          log.debug("SSH authentication succeeded");
        }
      }

      okay = true;
    } finally {
      if (!okay) {
        // If we fail to activate for any reason, we reset the
        // connection so that we start clean
        log.info("Resetting connection after failure to connect");
        close();
      }
    }
  }

  public MinaSshConnection(MinaSshContext context) {
    super();
    this.minaSshContext = context;
  }

  @Override
  protected ProcessExecution sshExecute0(String command, TimeSpan timeout) throws SshException, IOException,
      InterruptedException {
    try {
      ByteArrayOutputStream stdoutBinary = new ByteArrayOutputStream();
      ByteArrayOutputStream stderrBinary = new ByteArrayOutputStream();

      int exitCode = sshExecute(command, stdoutBinary, stderrBinary, null, timeout);

      ProcessExecution processExecution = new ProcessExecution(command, exitCode, stdoutBinary.toByteArray(),
          stderrBinary.toByteArray());
      return processExecution;
    } catch (RuntimeSshException e) {
      throw new SshException("Unexpected SSH error", e);
    }
  }

  int sshExecute(String command, final OutputStream stdout, final OutputStream stderr, ProcessStartListener listener,
      TimeSpan timeout) throws SshException, IOException, InterruptedException {
    ChannelSession sshChannel = null;
    try {
      sshChannel = ensureConnected().openSession(command);

      sshChannel.setIn(new ByteArrayInputStream(new byte[0]));
      sshChannel.setOut(stdout);
      sshChannel.setErr(stderr);
      try {
        sshChannel.open().await(DEFAULT_SSH_CONNECT_TIMEOUT.getTotalMilliseconds());
      } catch (Exception e) {
        ExceptionUtils.handleInterrupted(e);
        throw new SshException("Ssh error opening channel", e);
      }

      if (listener != null) {
        throw new UnsupportedOperationException();
        // listener.startedProcess(sshChannel.getStdin());
      }

      if (timeout == null) {
        timeout = TimeSpan.ZERO;
      }

      // Wait for everything to finish
      int flags = sshChannel.waitFor(ClientChannel.EOF | ClientChannel.CLOSED, timeout.getTotalMilliseconds());
      if ((flags & ClientChannel.TIMEOUT) != 0) {
        closeAndRemoveFromPool();
        throw new SshException("Timeout while waiting for SSH task to complete.  Timeout was " + timeout);
      }

      flags = sshChannel.waitFor(ClientChannel.EXIT_STATUS, 30000);
      if ((flags & ClientChannel.TIMEOUT) != 0) {
        closeAndRemoveFromPool();
        throw new SshException("Timeout while waiting for exit code.  Timeout was " + timeout);
      }

      Integer exitCode = getExitStatus(sshChannel);

      if (exitCode == null) {
        closeAndRemoveFromPool();
        throw new SshException("No exit code returned");
      }

      return exitCode;
    } finally {
      if (sshChannel != null) {
        sshChannel.close(false);
      }
    }
  }

  private Integer getExitStatus(ClientChannel sshChannel) {
    Integer exitCode = sshChannel.getExitStatus();
    return exitCode;
  }

  @Override
  protected void sshCopyData0(InputStream fileData, long dataLength, String remoteFile, String mode, boolean sudo)
      throws IOException, InterruptedException, SshException {
    int lastSlash = remoteFile.lastIndexOf('/');
    if (lastSlash == -1) {
      throw new IllegalArgumentException("Expected dest file to be absolute path: " + remoteFile);
    }

    MinaSshConnectionWrapper sshConnection = ensureConnected();

    MinaScpClient scp = new MinaScpClient(sshConnection);

    String remoteDir = remoteFile.substring(0, lastSlash);
    String filename = remoteFile.substring(lastSlash + 1);

    try {
      TimeSpan timeout = TimeSpan.FIVE_MINUTES;
      scp.put(fileData, dataLength, filename, remoteDir, mode, timeout, sudo);
    } catch (IOException ioException) {
      throw new SshException("Cannot doing scp on file", ioException);
    } catch (RuntimeSshException e) {
      throw new SshException("Error doing scp on file", e);
    }

  }

  @Override
  protected byte[] sshReadFile0(String remoteFile, boolean sudo) throws IOException, InterruptedException,
      SshException {
    MinaSshConnectionWrapper sshConnection = ensureConnected();

    MinaScpClient scp = new MinaScpClient(sshConnection);
    try {
      ByteArrayOutputStream baos = new ByteArrayOutputStream();
      TimeSpan timeout = TimeSpan.FIVE_MINUTES;
      scp.get(remoteFile, baos, timeout, sudo);
      return baos.toByteArray();
    } catch (IOException ioException) {
      throw new SshException("Cannot read file", ioException);
    } catch (SshException sshException) {
      String message = sshException.getMessage();
      if (message != null && message.endsWith(": No such file or directory")) {
        return null;
      }
      throw sshException;
    } catch (RuntimeSshException e) {
      throw new SshException("Error reading file", e);
    }
  }

  @Override
  public Socket buildTunneledSocket() throws IOException, SshException {
    MinaSshConnectionWrapper sshConnection = ensureConnected();

    return new SshTunnelSocket(sshConnection.sshClientSession);
  }

  @Override
  public SshConnection buildAgentConnection(KeyPair agentKeyPair) throws IOException, SshException {
    LocalAgentFactory agentFactory = new LocalAgentFactory();
    SshAgent agent = agentFactory.getAgent();
    try {
      agent.addIdentity(agentKeyPair, "default");
    } catch (IOException e) {
      throw new IllegalArgumentException("Error adding agent identity", e);
    }
    MinaSshConnection agentConnection = new MinaSshConnection(new MinaSshContext(agentFactory));

    agentConnection.setHost(this.getHost());
    agentConnection.setPort(this.getPort());
    agentConnection.setUser(this.getUser());
    agentConnection.setServerKeyVerifier(this.getServerKeyVerifier());
    agentConnection.setKeyPair(this.getKeyPair());

    agentConnection.ensureConnected();

    return agentConnection;
  }

  @Override
  public SshPortForward forwardLocalPort(InetSocketAddress remoteSocketAddress) throws IOException, SshException {
    MinaSshConnectionWrapper sshConnection = ensureConnected();

    final ForwardLocalPort forwardLocalPort = new ForwardLocalPort(sshConnection.sshClientSession,
        remoteSocketAddress);

    forwardLocalPort.start();

    return new SshPortForward() {
      @Override
      public void close() throws IOException {
        forwardLocalPort.close();
      }

      @Override
      public InetSocketAddress getLocalSocketAddress() {
        return forwardLocalPort.getLocalSocketAddress();
      }
    };
  }

}
TOP

Related Classes of org.platformlayer.ssh.mina.MinaSshConnection

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.