Package org.adbcj.mysql.codec

Source Code of org.adbcj.mysql.codec.MySqlClientDecoder$BoundedInputStream

/*
  This file is part of ADBCJ.

  ADBCJ is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  ADBCJ is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with ADBCJ.  If not, see <http://www.gnu.org/licenses/>.

  Copyright 2008  Mike Heath
*/
package org.adbcj.mysql.codec;
import java.io.IOException;
import java.io.InputStream;
import java.util.Set;

import org.adbcj.Value;
import org.adbcj.support.DefaultValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Client stateful decoder.  Being stateful, each client connection must have its own decoder instance to function
* properly.
*
* @author Mike Heath <mheath@apache.org>
*
*/
public class MySqlClientDecoder {
  private static final Logger logger = LoggerFactory.getLogger(MySqlClientDecoder.class);

  /**
   * The salt size in a server greeting
   */
  public static final int SALT_SIZE = 8;

  /**
   * The size of the second salt in a server greeting
   */
  public static final int SALT2_SIZE = 12;

  /**
   * Number of unused bytes in server greeting
   */
  public static final int GREETING_UNUSED_SIZE = 13;

  public static final int SQL_STATE_LENGTH = 5;

  public static final int RESPONSE_OK = 0x00;
  public static final int RESPONSE_EOF = 0xfe;
  public static final int RESPONSE_ERROR = 0xff;
  /**
   * The state of the decoder.
   */
  enum State {
    /**
     * The client is connecting
     */
    CONNECTING,
    RESPONSE,
    FIELD,
    FIELD_EOF,
    ROW
  }

  private State state = State.CONNECTING;

  private String charset = "ISO8859_1";

  /**
   * Holds the remaining number of field packets expected to build the result set
   */
  private int expectedFieldPackets = 0;

  private int fieldIndex = 0;

  /**
   * The field definitions for the current result set
   */
  private MysqlField[] fields;

  /**
   * Decodes a message from a MySql server.
   *
   * @param input  the {@code InputStream} from which to decode the message
   * @param block  true if the decoder can block, false otherwise
   * @return  the decode message, null if the {@code block} is {@code} false and there is not enough data available
   *         to decode the message without blocking
   * @throws IOException  thrown if an error occurs reading data from the inputstream
   */
  public ServerPacket decode(InputStream input, boolean block) throws IOException {
    // If mark is not support and we can't block, throw an exception
    if (!input.markSupported() && !block) {
      throw new IllegalArgumentException("Non-blocking decoding requires an InputStream that supports marking");
    }
    // TODO This should be the max packet size - make this configurable
    input.mark(Integer.MAX_VALUE);
    ServerPacket message = null;
    try {
      message = doDecode(input, block);
    } finally {
      if (message == null) {
        input.reset();
      }
    }
    return message;
  }

  protected ServerPacket doDecode(InputStream input, boolean block) throws IOException {
    // If we can't block, make sure there's enough data available to read
    if (!block) {
      if (input.available() < 3) {
        return null;
      }
    }
    // Read the packet length
    final int length = IoUtils.readUnsignedMediumInt(input);

    // If we can't block, make sure the stream has enough data
    if (!block) {
      // Make sure we have enough data for the packet length and the packet number
      if (input.available() < length + 1) {
        return null;
      }
    }
    final int packetNumber = IoUtils.safeRead(input);
    BoundedInputStream in = new BoundedInputStream(input, length);
    boolean threwException = false;
    try {

      logger.trace("Decoding in state {}", state);
      switch (state) {
      case CONNECTING:
        ServerGreeting serverGreeting = decodeServerGreeting(in, length, packetNumber);
        state = State.RESPONSE;
        return serverGreeting;
      case RESPONSE:
        int fieldCount = in.read();
        if (fieldCount == RESPONSE_OK) {
          // Create Ok response
          return decodeOkResponse(in, length, packetNumber);
        }
        if (fieldCount == RESPONSE_ERROR) {
          // Create error response
          ErrorResponse response = decodeErrorResponse(in, length, packetNumber);
          return response;
        }
        if (fieldCount == RESPONSE_EOF) {
          throw new IllegalStateException("Did not expect an EOF response from the server");
        }
        // Must be receiving result set header

        // Get the number of fields. The largest this can be is a 24-bit
        // integer so cast to int is ok
        expectedFieldPackets = (int)IoUtils.readBinaryLengthEncoding(in, fieldCount);
        fields = new MysqlField[expectedFieldPackets];
        logger.trace("Field count {}", expectedFieldPackets);

        Long extra = null;
        if (in.getRemaining() > 0) {
          extra = IoUtils.readBinaryLengthEncoding(in);
        }

        state = State.FIELD;

        return new ResultSetResponse(length, packetNumber, expectedFieldPackets, extra);
      case FIELD:
        ResultSetFieldResponse resultSetFieldResponse = decodeFieldResponse(in, length, packetNumber);

        expectedFieldPackets--;
        logger.trace("fieldPacketCount: {}", expectedFieldPackets);
        if (expectedFieldPackets == 0) {
          state = State.FIELD_EOF;
        }
        return resultSetFieldResponse;
      case FIELD_EOF:
        fieldCount = in.read();

        if (fieldCount != RESPONSE_EOF) {
          throw new IllegalStateException("Expected an EOF response from the server");
        }
        EofResponse fieldEof = decodeEofResponse(in, length, packetNumber, EofResponse.Type.FIELD);
        state = State.ROW;
        fieldIndex = 0;
        return fieldEof;
      case ROW:
        fieldCount = in.read(); // This is only for checking for EOF
        if (fieldCount == RESPONSE_EOF) {
          EofResponse rowEof = decodeEofResponse(in, length, packetNumber, EofResponse.Type.ROW);

          state = State.RESPONSE;

          return rowEof;
        }

        Value[] values = new Value[fields.length];
        for (int i = 0; i < fields.length; ) {
          MysqlField field = fields[i++];
          Object value = null;
          if (fieldCount != IoUtils.NULL_VALUE) {
            // We will have to move this as some datatypes will not be sent across the wire as strings
            String strVal = IoUtils.readLengthCodedString(in, fieldCount, charset);

            // TODO add decoding for all column types
            switch (field.getColumnType()) {
            case TINYINT:
              value = Byte.valueOf(strVal);
              break;
            case INTEGER:
            case BIGINT:
              value = Long.valueOf(strVal);
              break;
            case VARCHAR:
              value = strVal;
              break;
            default:
              throw new IllegalStateException("Don't know how to handle column type of "
                  + field.getColumnType());
            }
          }
          values[field.getIndex()] = new DefaultValue(field, value);
          if (i < fields.length) {
            fieldCount = in.read();
          }

        }
        return new ResultSetRowResponse(length, packetNumber, values);
      default:
        throw new IllegalStateException("Unkown decoder state " + state);
      }
    } catch (IOException e) {
      threwException = true;
      throw e;
    } catch (RuntimeException e) {
      threwException = true;
      throw e;
    } finally {
      if (!threwException && in.getRemaining() > 0) {
        throw new IllegalStateException("Buffer underrun occured; remaining bytes: " + in.getRemaining());
      }
    }
  }

  protected ServerGreeting decodeServerGreeting(InputStream in, int length, int packetNumber) throws IOException {
    int protocol = IoUtils.safeRead(in);
    String version = IoUtils.readString(in, "ASCII");
    int threadId = IoUtils.readInt(in);

    byte[] salt = new byte[SALT_SIZE + SALT2_SIZE];
    in.read(salt, 0, SALT_SIZE);
    in.read(); // Throw away 0 byte

    Set<ClientCapabilities> serverCapabilities = IoUtils.readEnumSetShort(in, ClientCapabilities.class);
    MysqlCharacterSet charSet = MysqlCharacterSet.findById(in.read());
    Set<ServerStatus> serverStatus = IoUtils.readEnumSetShort(in, ServerStatus.class);
    in.skip(GREETING_UNUSED_SIZE);

    in.read(salt, SALT_SIZE, SALT2_SIZE);
    in.read(); // Throw away 0 byte

    return new ServerGreeting(length, packetNumber, protocol, version, threadId, salt, serverCapabilities, charSet,
        serverStatus);
  }

  protected OkResponse decodeOkResponse(BoundedInputStream in, int length, int packetNumber) throws IOException {
    long affectedRows = IoUtils.readBinaryLengthEncoding(in);
    long insertId = IoUtils.readBinaryLengthEncoding(in);
    Set<ServerStatus> serverStatus = IoUtils.readEnumSetShort(in, ServerStatus.class);
    int warningCount = IoUtils.readUnsignedShort(in);
    String message = IoUtils.readFixedLengthString(in, in.getRemaining(), charset);

    return new OkResponse(length, packetNumber, affectedRows, insertId, serverStatus,
        warningCount, message);
  }

  protected ErrorResponse decodeErrorResponse(InputStream in, int length, int packetNumber) throws IOException {
    int errorNumber = IoUtils.readUnsignedShort(in);
    in.read(); // Throw away sqlstate marker
    String sqlState = IoUtils.readString(in, "ASCII");
    String message = IoUtils.readString(in, charset);
    return new ErrorResponse(length, packetNumber, errorNumber, sqlState, message);
  }

  protected EofResponse decodeEofResponse(InputStream in, int length, int packetNumber, EofResponse.Type type) throws IOException {
    int warnings = IoUtils.readUnsignedShort(in);
    Set<ServerStatus> serverStatus = IoUtils.readEnumSetShort(in, ServerStatus.class);

    return new EofResponse(length, packetNumber, warnings, serverStatus, type);
  }

  protected ResultSetFieldResponse decodeFieldResponse(InputStream in,
      int packetLength, int packetNumber) throws IOException {
    String catalogName = IoUtils.readLengthCodedString(in, charset);
    String schemaName = IoUtils.readLengthCodedString(in, charset);
    String tableLabel = IoUtils.readLengthCodedString(in, charset);
    String tableName = IoUtils.readLengthCodedString(in, charset);
    String columnLabel = IoUtils.readLengthCodedString(in, charset);
    String columnName = IoUtils.readLengthCodedString(in, charset);
    in.read(); // Skip filler
    int characterSetNumber = IoUtils.readUnsignedShort(in);
    MysqlCharacterSet charSet = MysqlCharacterSet.findById(characterSetNumber);
    long length = IoUtils.readUnsignedInt(in);
    int fieldTypeId = in.read();
    MysqlType fieldType = MysqlType.findById(fieldTypeId);
    Set<FieldFlag> flags = IoUtils.readEnumSet(in, FieldFlag.class);
    int decimals = in.read();
    in.skip(2); // Skip filler
    long fieldDefault = IoUtils.readBinaryLengthEncoding(in);
    MysqlField field = new MysqlField(fieldIndex, catalogName, schemaName, tableLabel, tableName, fieldType, columnLabel,
        columnName, 0, // Figure out precision
        decimals, charSet, length, flags, fieldDefault);
    fields[fieldIndex++] = field;
    return new ResultSetFieldResponse(packetLength, packetNumber, field);
  }

  // TODO: This stream implementation doesn't even work b ecause it doesn't delegate all InputStream methods
  private static class BoundedInputStream extends InputStream {

    private final InputStream in;
    private int remaining;

    public BoundedInputStream(InputStream in, int length) {
      this.in = in;
      this.remaining = length;
    }

    @Override
    public int read() throws IOException {
      int i = in.read();
      if (i >= 0) {
        remaining --;
      }
      if (remaining < 0) {
        throw new IllegalStateException("Buffer overrun");
      }
      return i;
    }

    @Override
    public int read(byte[] b, int off, int len) throws IOException {
      int i = in.read(b, off, len);
      remaining -= i;
      if (remaining < 0) {
        throw new IllegalStateException("Read too many bytes");
      }
      return i;
    }

    @Override
    public long skip(long n) throws IOException {
      long i = in.skip(n);
      remaining -= i;
      if (remaining < 0) {
        throw new IllegalStateException("Read too many bytes");
      }
      return i;
    }

    public int getRemaining() {
      return remaining;
    }

  }

  /**
   * Sets the state, used for testing.
   *
   * @param state
   */
  void setState(State state) {
    this.state = state;
  }
}
TOP

Related Classes of org.adbcj.mysql.codec.MySqlClientDecoder$BoundedInputStream

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.