/* Copyright (c) 2013 RelayRides
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
package com.relayrides.pushy.apns;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.MessageToByteEncoder;
import io.netty.handler.codec.ReplayingDecoder;
import io.netty.handler.ssl.SslHandler;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.relayrides.pushy.apns.ApnsConnection.ApnsFrameItem;
import com.relayrides.pushy.apns.util.SimpleApnsPushNotification;
public class MockApnsServer {
  private final int port;
  private final NioEventLoopGroup eventLoopGroup;
  private final ArrayList<CountDownLatch> countdownLatches = new ArrayList<CountDownLatch>();
  private Channel channel;
  private boolean shouldSendErrorResponses = true;
  public static final int EXPECTED_TOKEN_SIZE = 32;
  public static final int MAX_PAYLOAD_SIZE = 2048;
  private static final Logger log = LoggerFactory.getLogger(MockApnsServer.class);
  private class ApnsDecoderException extends Exception {
    private static final long serialVersionUID = 1L;
    final int sequenceNumber;
    final RejectedNotificationReason reason;
    public ApnsDecoderException(final int sequenceNumber, final RejectedNotificationReason reason) {
      this.sequenceNumber = sequenceNumber;
      this.reason = reason;
    }
  };
  private enum ApnsPushNotificationDecoderState {
    OPCODE,
    FRAME_LENGTH,
    FRAME;
  }
  private class ApnsPushNotificationDecoder extends ReplayingDecoder<ApnsPushNotificationDecoderState> {
    private int sequenceNumber;
    private Date deliveryInvalidation;
    private byte[] token;
    private byte[] payloadBytes;
    private DeliveryPriority priority;
    private byte[] frame;
    private boolean hasReceivedDeliveryInvalidationTime;
    private boolean hasReceivedSequenceNumber;
    private static final byte BINARY_NOTIFICATION_OPCODE = 2;
    public ApnsPushNotificationDecoder() {
      super(ApnsPushNotificationDecoderState.OPCODE);
    }
    @Override
    protected void decode(final ChannelHandlerContext context, final ByteBuf in, final List<Object> out) throws ApnsDecoderException {
      switch (this.state()) {
        case OPCODE: {
          this.sequenceNumber = 0;
          this.deliveryInvalidation = null;
          this.token = null;
          this.payloadBytes = null;
          this.priority = null;
          this.frame = null;
          this.hasReceivedDeliveryInvalidationTime = false;
          this.hasReceivedSequenceNumber = false;
          final byte opcode = in.readByte();
          if (opcode == BINARY_NOTIFICATION_OPCODE) {
            this.checkpoint(ApnsPushNotificationDecoderState.FRAME_LENGTH);
          } else {
            throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.UNKNOWN);
          }
          break;
        }
        case FRAME_LENGTH: {
          final int frameSize = in.readInt();
          if (frameSize < 1) {
            throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.UNKNOWN);
          }
          this.frame = new byte[frameSize];
          this.checkpoint(ApnsPushNotificationDecoderState.FRAME);
          break;
        }
        case FRAME: {
          in.readBytes(this.frame);
          out.add(this.decodeNotificationFromFrame(this.frame));
          this.checkpoint(ApnsPushNotificationDecoderState.OPCODE);
          break;
        }
      }
    }
    private SendableApnsPushNotification<SimpleApnsPushNotification> decodeNotificationFromFrame(final byte[] frame) throws ApnsDecoderException {
      final ByteBuffer buffer = ByteBuffer.wrap(frame);
      while (buffer.hasRemaining()) {
        try {
          final ApnsFrameItem item = ApnsFrameItem.getFrameItemFromCode(buffer.get());
          final short itemLength = buffer.getShort();
          switch (item) {
            case DEVICE_TOKEN: {
              if (this.token != null) {
                throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.UNKNOWN);
              }
              this.token = new byte[itemLength];
              if (this.token.length == 0) {
                throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.MISSING_TOKEN);
              } else if (this.token.length != EXPECTED_TOKEN_SIZE) {
                throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.INVALID_TOKEN_SIZE);
              }
              buffer.get(this.token);
              break;
            }
            case DELIVERY_INVALIDATION_TIME: {
              if (this.hasReceivedDeliveryInvalidationTime) {
                throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.UNKNOWN);
              }
              final long timestamp = (buffer.getInt() & 0xFFFFFFFFL) * 1000L;
              this.deliveryInvalidation = timestamp > 0 ? new Date(timestamp) : null;
              this.hasReceivedDeliveryInvalidationTime = true;
              break;
            }
            case PAYLOAD: {
              if (this.payloadBytes != null) {
                throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.UNKNOWN);
              }
              this.payloadBytes = new byte[itemLength];
              if (this.payloadBytes.length == 0) {
                throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.MISSING_PAYLOAD);
              } else if (this.payloadBytes.length > MAX_PAYLOAD_SIZE) {
                throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.INVALID_PAYLOAD_SIZE);
              }
              buffer.get(this.payloadBytes);
              break;
            }
            case PRIORITY: {
              if (this.priority != null) {
                throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.UNKNOWN);
              }
              this.priority = DeliveryPriority.getFromCode(buffer.get());
              break;
            }
            case SEQUENCE_NUMBER: {
              if (this.hasReceivedSequenceNumber) {
                throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.UNKNOWN);
              }
              this.sequenceNumber = buffer.getInt();
              this.hasReceivedSequenceNumber = true;
              break;
            }
          }
        } catch (final RuntimeException e) {
          throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.UNKNOWN);
        }
      }
      return this.constructPushNotification();
    }
    private SendableApnsPushNotification<SimpleApnsPushNotification> constructPushNotification() throws ApnsDecoderException {
      if (!this.hasReceivedSequenceNumber || !this.hasReceivedDeliveryInvalidationTime || this.token == null || this.payloadBytes == null || this.priority == null) {
        throw new ApnsDecoderException(this.sequenceNumber, RejectedNotificationReason.UNKNOWN);
      }
      final String payloadString = new String(this.payloadBytes, Charset.forName("UTF-8"));
      return new SendableApnsPushNotification<SimpleApnsPushNotification>(
          new SimpleApnsPushNotification(this.token, payloadString, this.deliveryInvalidation, this.priority),
          this.sequenceNumber);
    }
  }
  private class ApnsErrorEncoder extends MessageToByteEncoder<RejectedNotification> {
    private static final byte ERROR_COMMAND = 8;
    @Override
    protected void encode(final ChannelHandlerContext context, final RejectedNotification rejectedNotification, final ByteBuf out) {
      out.writeByte(ERROR_COMMAND);
      out.writeByte(rejectedNotification.getReason().getErrorCode());
      out.writeInt(rejectedNotification.getSequenceNumber());
    }
  }
  private class MockApnsServerHandler extends SimpleChannelInboundHandler<SendableApnsPushNotification<SimpleApnsPushNotification>> {
    private final MockApnsServer server;
    private boolean rejectFutureNotifications = false;
    public MockApnsServerHandler(final MockApnsServer server) {
      this.server = server;
    }
    @Override
    protected void channelRead0(final ChannelHandlerContext context, final SendableApnsPushNotification<SimpleApnsPushNotification> receivedNotification) {
      if (!this.rejectFutureNotifications) {
        this.server.acceptNotification(receivedNotification);
      }
    }
    @Override
    public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
      this.rejectFutureNotifications = true;
      if (cause instanceof DecoderException) {
        final DecoderException decoderException = (DecoderException)cause;
        if (decoderException.getCause() instanceof ApnsDecoderException) {
          if (this.server.shouldSendErrorResponses()) {
            final ApnsDecoderException apnsDecoderException = (ApnsDecoderException)decoderException.getCause();
            final RejectedNotification rejectedNotification =
                new RejectedNotification(apnsDecoderException.sequenceNumber, apnsDecoderException.reason);
            context.writeAndFlush(rejectedNotification).addListener(ChannelFutureListener.CLOSE);
          }
        }
      } else {
        log.warn("Caught an unexpected exception; closing connection.", cause);
        context.close();
      }
    }
  }
  public MockApnsServer(final int port, final NioEventLoopGroup eventLoopGroup) {
    this.port = port;
    this.eventLoopGroup = eventLoopGroup;
  }
  public synchronized void start() throws InterruptedException {
    final ServerBootstrap bootstrap = new ServerBootstrap();
    bootstrap.group(this.eventLoopGroup);
    bootstrap.channel(NioServerSocketChannel.class);
    bootstrap.childOption(ChannelOption.SO_KEEPALIVE, true);
    final MockApnsServer server = this;
    bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
      @Override
      protected void initChannel(final SocketChannel channel) throws Exception {
        channel.pipeline().addLast("ssl", new SslHandler(SSLTestUtil.createSSLEngineForMockServer()));
        channel.pipeline().addLast("encoder", new ApnsErrorEncoder());
        channel.pipeline().addLast("decoder", new ApnsPushNotificationDecoder());
        channel.pipeline().addLast("handler", new MockApnsServerHandler(server));
      }
    });
    this.channel = bootstrap.bind(this.port).await().channel();
  }
  public synchronized void shutdown() throws InterruptedException {
    if (this.channel != null) {
      this.channel.close().await();
    }
    this.channel = null;
  }
  public void setShouldSendErrorResponses(final boolean shouldSendErrorResponses) {
    this.shouldSendErrorResponses = shouldSendErrorResponses;
  }
  public boolean shouldSendErrorResponses() {
    return this.shouldSendErrorResponses;
  }
  private void acceptNotification(final SendableApnsPushNotification<SimpleApnsPushNotification> receivedNotification) {
    synchronized (this.countdownLatches) {
      for (final CountDownLatch latch : this.countdownLatches) {
        latch.countDown();
      }
    }
  }
  public CountDownLatch getAcceptedNotificationCountDownLatch(final int acceptedNotificationCount) {
    synchronized (this.countdownLatches) {
      final CountDownLatch latch = new CountDownLatch(acceptedNotificationCount);
      this.countdownLatches.add(latch);
      return latch;
    }
  }
}