Package mia.classifier.ch16.server

Source Code of mia.classifier.ch16.server.Server$ServerWatcher

/*
* Copyright 2010 Ted Dunning. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are
* permitted provided that the following conditions are met:
*
*    1. Redistributions of source code must retain the above copyright notice, this list of
*       conditions and the following disclaimer.
*
*    2. Redistributions in binary form must reproduce the above copyright notice, this list
*       of conditions and the following disclaimer in the documentation and/or other materials
*       provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY <COPYRIGHT HOLDER> ``AS IS'' AND ANY EXPRESS OR IMPLIED
* WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
* ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
* ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* The views and conclusions contained in the software and documentation are those of the
* authors and should not be interpreted as representing official policies, either expressed
* or implied, of <copyright holder>.
*/

package mia.classifier.ch16.server;

import com.google.common.base.Charsets;
import mia.classifier.ch16.generated.Classifier;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.server.TServer;
import org.apache.thrift.server.TThreadPoolServer;
import org.apache.thrift.transport.TServerSocket;
import org.apache.thrift.transport.TTransportException;
import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.WatchedEvent;
import org.apache.zookeeper.Watcher;
import org.apache.zookeeper.ZooDefs;
import org.apache.zookeeper.ZooKeeper;
import org.apache.zookeeper.data.Stat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.net.InetAddress;
import java.net.URL;
import java.net.UnknownHostException;
import java.util.Timer;
import java.util.TimerTask;

/**
* Basic classification server. This server watches a Zookeeper cluster to
* determine what models to load and what models to serve.
* <p/>
* The structure of data in ZK is as follows
* <p/>
*
* <pre>
* /model-service/
*   current-servers/        Contains one file per live server.
*   model-to-serve          Contains URL of live model.  Reread on changes.
* </pre>
*/
public class Server {
  public static final String ZK_BASE = "/model-service";
  public static final String ZK_CURRENT_SERVERS = ZK_BASE
      + "/current-servers";
  public static final String ZK_MODEL = ZK_BASE + "/model-to-serve";

  private final TServer server;
  private final Logger log = LoggerFactory.getLogger(this.getClass());

  private Timer timer;

  private ZooKeeper zk;

  private ServerWatcher modelWatcher = new ServerWatcher();

  public Server(int port) throws TTransportException, IOException,
      InterruptedException, KeeperException {
    zk = new ZooKeeper("localhost", 2181, new Watcher() {
      @Override
      public void process(WatchedEvent watchedEvent) {
        // ignore
      }
    });

    if (zk.exists(ZK_BASE, null) == null) {
      log.warn("Creating " + ZK_BASE);
      zk.create(ZK_BASE, new byte[0], ZooDefs.Ids.OPEN_ACL_UNSAFE,
          CreateMode.PERSISTENT);
    }
    if (zk.exists(ZK_CURRENT_SERVERS, null) == null) {
      log.warn("Creating " + ZK_CURRENT_SERVERS);
      zk.create(ZK_CURRENT_SERVERS, new byte[0],
          ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
    }
    zk.close();

    Ops modelHandler = new Ops();
    modelWatcher.setModelHandler(modelHandler);

    // schedule a retry every thirty seconds in case we can't reset the
    // watch
    timer = new Timer();
    timer.scheduleAtFixedRate(new TimerTask() {
      @Override
      public void run() {
        modelWatcher.process(null);
      }
    }, 0, 3000);

    try {
      TServerSocket socket = new TServerSocket(port);
      Classifier.Processor processor = new Classifier.Processor(
          modelHandler);

      TProtocolFactory protocol = new TBinaryProtocol.Factory(true, true);
      server = new TThreadPoolServer(
          new TThreadPoolServer.Args(socket).processor(processor));

      log.warn("Starting server on port {}", port);
      server.serve();
    } finally {
      timer.cancel();
      modelWatcher.close();
    }
  }

  public void close() throws InterruptedException {
    log.warn("Exiting");
    server.stop();
    timer.cancel();
    zk.close();
  }

  public static void main(String[] args) throws IOException,
      TTransportException, InterruptedException, KeeperException {
    new Server(7908);
  }

  private static class ServerWatcher implements Watcher {
    private final Logger log = LoggerFactory.getLogger(this.getClass());

    private Ops modelHandler;

    private String currentUrl = null;
    private int version;

    private ZooKeeper zk = null;
    private String hostname;

    private ServerWatcher() {
      hostname = null;
      try {
        hostname = InetAddress.getLocalHost().getHostName();
      } catch (UnknownHostException e) {
        // continue with null hostname
      }
      if (hostname == null) {
        log.error("Must have hostname ... exiting");
        System.exit(1);
      }
    }

    /**
     * Loads or reloads the model by looking at ZK to get the model URL,
     * then loads that URL to get the serialized model.
     *
     * @param watchedEvent
     *            Ignored.
     */
    @Override
    public void process(WatchedEvent watchedEvent) {
      if (zk == null) {
        try {
          zk = new ZooKeeper("localhost", 2181, null);
        } catch (IOException e) {
          zk = null;
          return;
        }
      }

      String url = null;
      try {
        // get new URL
        Stat stat = new Stat();
        byte[] urlAsBytes = zk.getData(ZK_MODEL, this, stat);
        int latestVersion = stat.getVersion();

        url = new String(urlAsBytes, Charsets.UTF_8);

        // check for change
        URL modelUrl = new URL(url);
        boolean needUpdate = false;
        if (currentUrl == null || latestVersion != version) {
          log.warn("Loading model from " + modelUrl);

          AbstractVectorClassifier model = ModelSerializer
              .readBinary(modelUrl.openStream(),
                  OnlineLogisticRegression.class);

          modelHandler.setModel(model);
          currentUrl = url;
          version = latestVersion;
          log.info("done loading version " + version);
          needUpdate = true;
        }

        // update status file so clients find us
        String statusFile = ZK_CURRENT_SERVERS + "/" + hostname;
        // Tell ZK what model we loaded. We try to do this often because
        // we might have previously
        // updated a lingering ephemeral file belonging to a previous
        // incarnation. After
        // a short time, that ephemeral may disappear and we would need
        // to restore it
        try {
          zk.create(statusFile,
              modelUrl.toString().getBytes(Charsets.UTF_8),
              ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL);
          log.info("created server file {}", statusFile);
        } catch (KeeperException.NodeExistsException e) {
          if (needUpdate) {
            zk.setData(statusFile,
                modelUrl.toString().getBytes(Charsets.UTF_8),
                -1);
            log.info("updated server file {}", statusFile);
          }
        } catch (KeeperException e) {
          log.error("Couldn't write server status file");
        }

        return;
      } catch (KeeperException.NoNodeException e) {
        // if no such data on ZK, log it and continue.
        log.error("Could not find model URL in ZK file: " + ZK_MODEL, e);
        return;
      } catch (KeeperException.SessionExpiredException e) {
        log.error("Session expired", e);
        zk = null;
      } catch (KeeperException e) {
        log.error("Failed to load model due to ZK exception", e);
      } catch (InterruptedException e) {
        log.error("Operation interrupted should never happen", e);
      } catch (IOException e) {
        log.error("Failed to load model from " + url, e);
      }

      // only get here on error
      log.warn("Clearing current URL due to error");
      currentUrl = null;
      version = -1;
    }

    public void setModelHandler(Ops modelHandler) {
      this.modelHandler = modelHandler;
    }

    public void close() throws InterruptedException {
      zk.close();
    }
  }
}
TOP

Related Classes of mia.classifier.ch16.server.Server$ServerWatcher

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.