Package org.platformlayer.ssh.mina

Source Code of org.platformlayer.ssh.mina.MinaSshConnectionWrapper

package org.platformlayer.ssh.mina;

import java.io.Closeable;
import java.io.IOException;
import java.security.KeyPair;

import org.apache.sshd.ClientSession;
import org.apache.sshd.SshClient;
import org.apache.sshd.client.ServerKeyVerifier;
import org.apache.sshd.client.channel.ChannelSession;
import org.apache.sshd.client.future.ConnectFuture;
import org.platformlayer.ExceptionUtils;
import org.platformlayer.ops.ssh.IServerKeyVerifier;
import org.platformlayer.ops.ssh.SshException;
import org.platformlayer.ssh.SshConnectionInfo;
import org.platformlayer.ssh.mina.bugfix.BugFixChannelExec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fathomdb.TimeSpan;

public class MinaSshConnectionWrapper implements Closeable {
  private static final Logger log = LoggerFactory.getLogger(MinaSshConnectionWrapper.class);

  final SshConnectionInfo connectionInfo;

  ClientSession sshClientSession;
  SshConnectionState state;

  private final MinaSshContext sshContext;

  private IServerKeyVerifier serverKeyVerifier;

  enum SshConnectionState {
    NotConnected, Connected, Authenticated, Closed
  }

  MinaSshConnectionWrapper(MinaSshContext sshContext, SshConnectionInfo connectionInfo) {
    super();
    this.sshContext = sshContext;
    this.connectionInfo = connectionInfo;
    this.state = SshConnectionState.NotConnected;
  }

  public SshConnectionInfo getConnectionInfo() {
    return connectionInfo;
  }

  @Override
  public void close() {
    if (sshClientSession != null) {
      sshClientSession.close(true);
      sshClientSession = null;
    }
    this.state = SshConnectionState.Closed;
  }

  public boolean isConnected() {
    switch (state) {
    case Authenticated:
    case Connected:
      return true;
    case NotConnected:
      return false;
    case Closed:
      return false;
    default:
      throw new IllegalStateException();
    }
  }

  public boolean isAuthenticationComplete() {
    switch (state) {
    case Authenticated:
      return true;
    case Connected:
      return false;
    case NotConnected:
    case Closed:
      throw new IllegalStateException();
    default:
      throw new IllegalStateException();
    }
  }

  public ChannelSession openSession(String command) throws SshException {
    try {
      boolean useAgentForwarding = sshContext.isAgentForwarding();

      ChannelSession clientChannel = BugFixChannelExec.createExecChannel(sshClientSession, command,
          useAgentForwarding);
      return clientChannel;
    } catch (Exception e) {
      ExceptionUtils.handleInterrupted(e);
      throw new SshException("Error creating channel", e);
    }
  }

  public boolean isConnectionClosed() {
    switch (state) {
    case Authenticated:
    case Connected:
      return false;
    case Closed:
      return true;
    default:
      throw new IllegalStateException();
    }
  }

  public static MinaSshConnectionWrapper build(MinaSshContext sshContext, SshConnectionInfo connectionInfo)
      throws SshException {
    return new MinaSshConnectionWrapper(sshContext, connectionInfo);
  }

  public void connect(TimeSpan connectTimeout) throws SshException {
    if (state != SshConnectionState.NotConnected) {
      throw new IllegalStateException();
    }

    try {
      SshClient client = sshContext.client;

      System.out.println("New SSH connection to " + connectionInfo.getHost());

      ConnectFuture connect = client.connect(connectionInfo.getSocketAddress());
      if (!connect.await(connectTimeout.getTotalMilliseconds())) {
        connect.cancel();
        throw new SshException("Timeout while waiting for SSH connection to " + connectionInfo.getHost());
      }

      this.sshClientSession = connect.getSession();

      if (this.sshClientSession == null) {
        throw new IllegalStateException();
      }

      this.state = SshConnectionState.Connected;
    } catch (Exception e) {
      ExceptionUtils.handleInterrupted(e);
      throw new SshException("Error connecting to SSH server @" + connectionInfo.getSocketAddress(), e);
    }
  }

  public boolean authenticateWithPublicKey(String user, KeyPair key, TimeSpan authenticationTimeout)
      throws SshException {
    if (state != SshConnectionState.Connected) {
      throw new IllegalStateException();
    }

    // DelegatingServerKeyVerifier picks this up
    if (serverKeyVerifier != null) {
      sshClientSession.getMetadataMap().put(ServerKeyVerifier.class,
          new ServerKeyVerifierAdapter(serverKeyVerifier));
    }

    // The 0xfff mask with -1 timeout trick lets us get the current state
    int sshSessionStateCode = sshClientSession.waitFor(0xfff, -1);

    if ((sshSessionStateCode & ClientSession.AUTHED) == 0) {
      try {
        sshClientSession.authPublicKey(user, key);
      } catch (IOException e) {
        throw new SshException("I/O error while authenticating " + this, e);
      } catch (IllegalStateException e) {
        throw new SshException("Error authenticating " + this, e);
      }

      sshSessionStateCode = sshClientSession.waitFor(ClientSession.CLOSED | ClientSession.AUTHED
          | ClientSession.WAIT_AUTH, authenticationTimeout.getTotalMilliseconds());
    }

    if ((sshSessionStateCode & ClientSession.WAIT_AUTH) != 0) {
      throw new SshException("SSH authentication failed (trying " + connectionInfo + ")");
    }

    if ((sshSessionStateCode & ClientSession.CLOSED) != 0) {
      this.sshClientSession = null;
      state = SshConnectionState.Closed;
      throw new SshException("SSH connection closed while trying to login");
    }

    // if (serverKeyVerifier != null) {
    // byte[] serverKey = getServerKey(sshClientSession);
    // SocketAddress remoteAddress = new InetSocketAddress(connectionInfo.getHost(), connectionInfo.getPort());
    // if (!serverKeyVerifier.verifyServerKey(remoteAddress, serverKey)) {
    // sshClientSession.close(true);
    // throw new SshException("SSH key verification failed");
    // }
    // }

    if ((sshSessionStateCode & ClientSession.AUTHED) != 0) {
      state = SshConnectionState.Authenticated;
      return true;
    }

    if (sshSessionStateCode == ClientSession.TIMEOUT) {
      throw new SshException("Timeout while waiting for session");
    }
    throw new SshException("Unexpected state; not authenticated; not closed");
  }

  @Override
  public String toString() {
    return "Ssh connection: " + connectionInfo + "; Current state: " + state.toString();
  }

  public IServerKeyVerifier getServerKeyVerifier() {
    return serverKeyVerifier;
  }

  public void setServerKeyVerifier(IServerKeyVerifier serverKeyVerifier) {
    this.serverKeyVerifier = serverKeyVerifier;
  }
}
TOP

Related Classes of org.platformlayer.ssh.mina.MinaSshConnectionWrapper

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.