Package org.h2.dev.net

Source Code of org.h2.dev.net.PgTcpRedirect$TcpRedirectThread

/*
* Copyright 2004-2011 H2 Group. Multiple-Licensed under the H2 License,
* Version 1.0, and under the Eclipse Public License, Version 1.0
* (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.dev.net;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;

/**
 * This class helps debug the PostgreSQL network protocol.
 * It listens on one port and sends the exact same data to another port on
 * localhost, decoding every message it forwards. The decoded messages and the
 * raw bytes are printed to {@code System.out} when {@link #DEBUG} is enabled.
 * <p>
 * Command line options: {@code -client <port>} (the port to listen on,
 * default 5433) and {@code -server <port>} (the port to forward to, default
 * 5432).
 * <p>
 * Encryption requests (SSLRequest, GSSENCRequest) are answered with 'N' by the
 * redirector itself and are not forwarded, so that the session stays in plain
 * text and can be decoded.
 */
public class PgTcpRedirect {

    private static final boolean DEBUG = false;

    /** Protocol version code of a CancelRequest packet (1234.5678). */
    private static final int CANCEL_REQUEST_CODE = 80877102;

    /** Protocol version code of an SSLRequest packet (1234.5679). */
    private static final int SSL_REQUEST_CODE = 80877103;

    /** Protocol version code of a GSSENCRequest packet (1234.5680). */
    private static final int GSSENC_REQUEST_CODE = 80877104;

    /**
     * This method is called when executing this application from the command
     * line.
     *
     * @param args the command line parameters
     */
    public static void main(String... args) throws Exception {
        new PgTcpRedirect().loop(args);
    }

    /**
     * Accept client connections forever. For each client, open a connection
     * to the target port and start one redirect thread per direction.
     *
     * @param args the command line parameters
     */
    private void loop(String... args) throws Exception {
        // MySQL protocol:
        // http://www.redferni.uklinux.net/mysql/MySQL-Protocol.html
        // PostgreSQL protocol:
        // https://www.postgresql.org/docs/current/protocol.html
        // Other useful port combinations:
        // int portServer = 9083, portClient = 9084;
        // MySQL: int portServer = 3306, portClient = 3307;
        // H2 PgServer: int portServer = 5435, portClient = 5433;
        // PostgreSQL (default)
        int portServer = 5432, portClient = 5433;

        for (int i = 0; i < args.length; i++) {
            if ("-client".equals(args[i])) {
                portClient = Integer.parseInt(args[++i]);
            } else if ("-server".equals(args[i])) {
                portServer = Integer.parseInt(args[++i]);
            }
        }
        ServerSocket listener = new ServerSocket(portClient);
        try {
            while (true) {
                Socket client = listener.accept();
                Socket server;
                try {
                    server = new Socket("localhost", portServer);
                } catch (IOException e) {
                    // the target is not reachable: drop only this client
                    // (instead of leaking it) and keep accepting connections
                    e.printStackTrace();
                    closeQuietly(client);
                    continue;
                }
                TcpRedirectThread c = new TcpRedirectThread(client, server, true);
                TcpRedirectThread s = new TcpRedirectThread(server, client, false);
                new Thread(c).start();
                new Thread(s).start();
            }
        } finally {
            listener.close();
        }
    }

    /**
     * Close a socket, ignoring any exception.
     *
     * @param socket the socket to close
     */
    private static void closeQuietly(Socket socket) {
        try {
            socket.close();
        } catch (IOException e) {
            // ignore: nothing useful can be done if closing fails
        }
    }

    /**
     * This is the working thread of the TCP redirector. Each instance copies
     * one direction of a connection, one protocol message at a time.
     */
    private static class TcpRedirectThread implements Runnable {

        private static final int STATE_INIT_CLIENT = 0, STATE_REGULAR = 1;

        /** The socket messages are read from. */
        private final Socket read;

        /** The socket messages are forwarded to. */
        private final Socket write;

        /** Whether this thread handles the client-to-server direction. */
        private final boolean client;

        /** Whether the next client packet is a startup packet (no type byte). */
        private int state;

        TcpRedirectThread(Socket read, Socket write, boolean client) {
            this.read = read;
            this.write = write;
            this.client = client;
            state = client ? STATE_INIT_CLIENT : STATE_REGULAR;
        }

        /**
         * Read a null-terminated string, converting each byte to one char.
         * The end of the stream also terminates the string.
         *
         * @param in the input stream
         * @return the string (without the terminator)
         */
        static String readStringNull(InputStream in) throws IOException {
            StringBuilder buff = new StringBuilder();
            while (true) {
                int x = in.read();
                if (x <= 0) {
                    break;
                }
                buff.append((char) x);
            }
            return buff.toString();
        }

        private static void println(String s) {
            if (DEBUG) {
                System.out.println(s);
            }
        }

        /**
         * Read the Int32 length word (which counts itself) and the rest of a
         * message, copying both verbatim to the given output.
         *
         * @param in the stream to read from
         * @param copy receives the length word and the body
         * @return the message body, without the length word
         * @throws IOException if reading fails or the length is invalid
         */
        private static byte[] readBody(DataInputStream in, DataOutputStream copy) throws IOException {
            int len = in.readInt();
            if (len < 4) {
                throw new IOException("Invalid message length: " + len);
            }
            copy.writeInt(len);
            byte[] data = new byte[len - 4];
            in.readFully(data);
            copy.write(data);
            return data;
        }

        /**
         * Wrap a message body so it can be decoded independently of the
         * socket stream.
         */
        private static DataInputStream asInput(byte[] data) {
            return new DataInputStream(new ByteArrayInputStream(data));
        }

        /**
         * Skip a value of the given length; a negative length (-1 = NULL)
         * has no value bytes.
         */
        private static void skip(DataInputStream in, int len) throws IOException {
            if (len > 0) {
                in.readFully(new byte[len]);
            }
        }

        /**
         * Print the raw bytes and send them to the other side. Write errors
         * are propagated so that the connection pair is torn down.
         */
        private static void forward(byte[] buffer, OutputStream out) throws IOException {
            printData(buffer, buffer.length);
            out.write(buffer, 0, buffer.length);
            out.flush();
        }

        /**
         * Read, decode, and forward one message sent by the client.
         *
         * @param inStream the stream from the client
         * @param outStream the stream to the server
         * @return false if the client closed the connection
         */
        private boolean processClient(InputStream inStream, OutputStream outStream) throws IOException {
            DataInputStream dataIn = new DataInputStream(inStream);
            ByteArrayOutputStream buff = new ByteArrayOutputStream();
            DataOutputStream dataOut = new DataOutputStream(buff);
            if (state == STATE_INIT_CLIENT) {
                // startup packets have no type byte, only a length word
                byte[] data = readBody(dataIn, dataOut);
                DataInputStream msg = asInput(data);
                int version = msg.readInt();
                if (version == SSL_REQUEST_CODE || version == GSSENC_REQUEST_CODE) {
                    println(version == SSL_REQUEST_CODE ? "SSLRequest" : "GSSENCRequest");
                    println(" answered with N by the redirector (not forwarded)");
                    byte[] request = buff.toByteArray();
                    printData(request, request.length);
                    // The server would answer with a single byte that is not a
                    // regular message, and an encrypted session could not be
                    // decoded anyway: refuse encryption here. The client then
                    // sends a new startup packet, so the state stays unchanged.
                    OutputStream toClient = read.getOutputStream();
                    toClient.write('N');
                    toClient.flush();
                    return true;
                }
                state = STATE_REGULAR;
                try {
                    describeStartup(version, msg);
                } catch (IOException e) {
                    println(" could not decode the message: " + e);
                }
            } else {
                int x = dataIn.read();
                if (x < 0) {
                    println("end");
                    return false;
                }
                dataOut.write(x);
                byte[] data = readBody(dataIn, dataOut);
                // decoding is for debugging only: a failure must not stop forwarding
                try {
                    describeClientMessage(x, asInput(data));
                } catch (IOException e) {
                    println(" could not decode the message: " + e);
                }
            }
            forward(buff.toByteArray(), outStream);
            return true;
        }

        /**
         * Print a CancelRequest or StartupMessage.
         *
         * @param version the protocol version code read from the packet
         * @param dataIn the rest of the packet body
         */
        private static void describeStartup(int version, DataInputStream dataIn) throws IOException {
            if (version == CANCEL_REQUEST_CODE) {
                println("CancelRequest");
                println(" pid: " + dataIn.readInt());
                println(" key: " + dataIn.readInt());
            } else {
                println("StartupMessage");
                println(" version " + version + " (" + (version >> 16) + "." + (version & 0xffff) + ")");
                while (true) {
                    String param = readStringNull(dataIn);
                    if (param.length() == 0) {
                        break;
                    }
                    String value = readStringNull(dataIn);
                    println(" param " + param + "=" + value);
                }
            }
        }

        /**
         * Print a regular (typed) message sent by the client.
         *
         * @param x the message type byte
         * @param dataIn the message body, without type and length
         */
        private static void describeClientMessage(int x, DataInputStream dataIn) throws IOException {
            switch (x) {
            case 'B': {
                println("Bind");
                println(" destPortal: " + readStringNull(dataIn));
                println(" prepName: " + readStringNull(dataIn));
                int formatCodesCount = dataIn.readShort();
                for (int i = 0; i < formatCodesCount; i++) {
                    println(" formatCode[" + i + "]=" + dataIn.readShort());
                }
                int paramCount = dataIn.readShort();
                for (int i = 0; i < paramCount; i++) {
                    int paramLen = dataIn.readInt();
                    println(" length[" + i + "]=" + paramLen);
                    // a length of -1 is a NULL parameter without value bytes
                    skip(dataIn, paramLen);
                }
                int resultCodeCount = dataIn.readShort();
                for (int i = 0; i < resultCodeCount; i++) {
                    println(" resultCodeCount[" + i + "]=" + dataIn.readShort());
                }
                break;
            }
            case 'C': {
                println("Close");
                println(" type: (S:prepared statement, P:portal): " + (char) dataIn.readByte());
                break;
            }
            case 'd': {
                println("CopyData");
                break;
            }
            case 'c': {
                println("CopyDone");
                break;
            }
            case 'f': {
                println("CopyFail");
                println(" message: " + readStringNull(dataIn));
                break;
            }
            case 'D': {
                println("Describe");
                println(" type (S=prepared statement, P=portal): " + (char) dataIn.readByte());
                println(" name: " + readStringNull(dataIn));
                break;
            }
            case 'E': {
                println("Execute");
                println(" name: " + readStringNull(dataIn));
                // the row limit is an Int32
                println(" maxRows: " + dataIn.readInt());
                break;
            }
            case 'H': {
                println("Flush");
                break;
            }
            case 'F': {
                println("FunctionCall");
                println(" objectId:" + dataIn.readInt());
                int columns = dataIn.readShort();
                for (int i = 0; i < columns; i++) {
                    println(" formatCode[" + i + "]: " + dataIn.readShort());
                }
                int count = dataIn.readShort();
                for (int i = 0; i < count; i++) {
                    int l = dataIn.readInt();
                    println(" len[" + i + "]: " + l);
                    skip(dataIn, l);
                }
                println(" resultFormat: " + dataIn.readShort());
                break;
            }
            case 'P': {
                println("Parse");
                println(" name:" + readStringNull(dataIn));
                println(" query:" + readStringNull(dataIn));
                int count = dataIn.readShort();
                for (int i = 0; i < count; i++) {
                    println(" [" + i + "]: " + dataIn.readInt());
                }
                break;
            }
            case 'p': {
                println("PasswordMessage");
                println(" password: " + readStringNull(dataIn));
                break;
            }
            case 'Q': {
                println("Query");
                println(" sql : " + readStringNull(dataIn));
                break;
            }
            case 'S': {
                println("Sync");
                break;
            }
            case 'X': {
                println("Terminate");
                break;
            }
            default:
                println("############## UNSUPPORTED: " + (char) x);
            }
        }

        /**
         * Read, decode, and forward one message sent by the server.
         *
         * @param inStream the stream from the server
         * @param outStream the stream to the client
         * @return false if the server closed the connection
         */
        private boolean processServer(InputStream inStream, OutputStream outStream) throws IOException {
            DataInputStream dataIn = new DataInputStream(inStream);
            ByteArrayOutputStream buff = new ByteArrayOutputStream();
            DataOutputStream dataOut = new DataOutputStream(buff);
            int x = dataIn.read();
            if (x < 0) {
                println("end");
                return false;
            }
            dataOut.write(x);
            byte[] data = readBody(dataIn, dataOut);
            // decoding is for debugging only: a failure must not stop forwarding
            try {
                describeServerMessage(x, asInput(data));
            } catch (IOException e) {
                println(" could not decode the message: " + e);
            }
            forward(buff.toByteArray(), outStream);
            return true;
        }

        /**
         * Print the fields of an ErrorResponse or NoticeResponse.
         * Field types, see
         * https://www.postgresql.org/docs/current/protocol-error-fields.html:
         * S severity, C SQLSTATE code, M message, D detail, H hint,
         * P position, p internal position, q internal query, W where,
         * F file, L line, R routine.
         */
        private static void describeFields(DataInputStream dataIn) throws IOException {
            while (true) {
                int fieldType = dataIn.readByte();
                if (fieldType == 0) {
                    break;
                }
                String msg = readStringNull(dataIn);
                println(" fieldType: " + (char) fieldType + " msg: " + msg);
            }
        }

        /**
         * Print a message sent by the server.
         *
         * @param x the message type byte
         * @param dataIn the message body, without type and length
         */
        private static void describeServerMessage(int x, DataInputStream dataIn) throws IOException {
            switch (x) {
            case 'R': {
                println("Authentication");
                int value = dataIn.readInt();
                if (value == 0) {
                    println(" Ok");
                } else if (value == 2) {
                    println(" KerberosV5");
                } else if (value == 3) {
                    println(" CleartextPassword");
                } else if (value == 4) {
                    println(" CryptPassword");
                    byte b1 = dataIn.readByte();
                    byte b2 = dataIn.readByte();
                    println(" salt1=" + b1 + " salt2=" + b2);
                } else if (value == 5) {
                    println(" MD5Password");
                    byte b1 = dataIn.readByte();
                    byte b2 = dataIn.readByte();
                    byte b3 = dataIn.readByte();
                    byte b4 = dataIn.readByte();
                    println(" salt1=" + b1 + " salt2=" + b2 + " 3=" + b3 + " 4=" + b4);
                } else if (value == 6) {
                    println(" SCMCredential");
                } else {
                    println(" method " + value);
                }
                break;
            }
            case 'K': {
                println("BackendKeyData");
                println(" process ID " + dataIn.readInt());
                println(" key " + dataIn.readInt());
                break;
            }
            case '2': {
                println("BindComplete");
                break;
            }
            case '3': {
                println("CloseComplete");
                break;
            }
            case 'C': {
                println("CommandComplete");
                println(" command tag: " + readStringNull(dataIn));
                break;
            }
            case 'd': {
                println("CopyData");
                break;
            }
            case 'c': {
                println("CopyDone");
                break;
            }
            case 'G': {
                println("CopyInResponse");
                println(" format: " + dataIn.readByte());
                int columns = dataIn.readShort();
                for (int i = 0; i < columns; i++) {
                    println(" formatCode[" + i + "]: " + dataIn.readShort());
                }
                break;
            }
            case 'H': {
                println("CopyOutResponse");
                println(" format: " + dataIn.readByte());
                int columns = dataIn.readShort();
                for (int i = 0; i < columns; i++) {
                    println(" formatCode[" + i + "]: " + dataIn.readShort());
                }
                break;
            }
            case 'D': {
                println("DataRow");
                int columns = dataIn.readShort();
                println(" columns : " + columns);
                for (int i = 0; i < columns; i++) {
                    // a length of -1 is a NULL value without value bytes
                    skip(dataIn, dataIn.readInt());
                }
                break;
            }
            case 'I': {
                println("EmptyQueryResponse");
                break;
            }
            case 'E': {
                println("ErrorResponse");
                describeFields(dataIn);
                break;
            }
            case 'V': {
                println("FunctionCallResponse");
                int resultLen = dataIn.readInt();
                println(" len: " + resultLen);
                break;
            }
            case 'n': {
                println("NoData");
                break;
            }
            case 'N': {
                println("NoticeResponse");
                describeFields(dataIn);
                break;
            }
            case 'A': {
                println("NotificationResponse");
                println(" processID: " + dataIn.readInt());
                println(" condition: " + readStringNull(dataIn));
                println(" information: " + readStringNull(dataIn));
                break;
            }
            case 't': {
                // Int16 parameter count, followed by one Int32 type OID each
                println("ParameterDescription");
                int count = dataIn.readShort();
                for (int i = 0; i < count; i++) {
                    println(" [" + i + "] objectId: " + dataIn.readInt());
                }
                break;
            }
            case 'S': {
                println("ParameterStatus");
                println(" parameter " + readStringNull(dataIn) + " = " + readStringNull(dataIn));
                break;
            }
            case '1': {
                println("ParseComplete");
                break;
            }
            case 's': {
                println("PortalSuspended");
                break;
            }
            case 'Z': {
                println("ReadyForQuery");
                println(" status (I:idle, T:transaction, E:failed): " + (char) dataIn.readByte());
                break;
            }
            case 'T': {
                println("RowDescription");
                int columns = dataIn.readShort();
                println(" columns : " + columns);
                for (int i = 0; i < columns; i++) {
                    println(" [" + i + "]");
                    println("  name:" + readStringNull(dataIn));
                    println("  tableId:" + dataIn.readInt());
                    println("  columnId:" + dataIn.readShort());
                    println("  dataTypeId:" + dataIn.readInt());
                    println("  dataTypeSize (pg_type.typlen):" + dataIn.readShort());
                    println("  modifier (pg_attribute.atttypmod):" + dataIn.readInt());
                    println("  format code:" + dataIn.readShort());
                }
                break;
            }
            default:
                println("############## UNSUPPORTED: " + (char) x);
            }
        }

        /**
         * Forward messages until the sending side closes the connection or an
         * I/O error occurs. Both sockets are always closed at the end, which
         * also makes the thread of the opposite direction stop.
         */
        public void run() {
            try {
                OutputStream out = write.getOutputStream();
                InputStream in = read.getInputStream();
                boolean more = true;
                while (more) {
                    more = client ? processClient(in, out) : processServer(in, out);
                }
            } catch (IOException e) {
                // If the thread of the other direction already closed our
                // socket, the exception is the expected way of stopping.
                if (!read.isClosed()) {
                    e.printStackTrace();
                }
            } finally {
                closeQuietly(read);
                closeQuietly(write);
            }
        }
    }

    /**
     * Print the uninterpreted byte array if debugging is enabled. Printable
     * ASCII characters are shown as they are. All other bytes, and the
     * bracket characters themselves, are shown as bracketed hex values.
     *
     * @param buffer the byte array
     * @param len the length
     */
    static synchronized void printData(byte[] buffer, int len) {
        if (DEBUG) {
            StringBuilder buff = new StringBuilder(" ");
            for (int i = 0; i < len; i++) {
                int c = buffer[i] & 255;
                if (c >= ' ' && c < 127 && c != '[' && c != ']') {
                    buff.append((char) c);
                } else {
                    buff.append('[').append(Integer.toHexString(c)).append(']');
                }
            }
            System.out.println(buff);
        }
    }
}
TOP

Related Classes of org.h2.dev.net.PgTcpRedirect$TcpRedirectThread

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.