Package ivory.smrf.retrieval

Source Code of ivory.smrf.retrieval.BatchQueryRunner

/*
* Ivory: A Hadoop toolkit for web-scale information retrieval
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package ivory.smrf.retrieval;

import ivory.core.ConfigurationException;
import ivory.core.RetrievalEnvironment;
import ivory.core.util.ResultWriter;
import ivory.core.util.XMLTools;
import ivory.smrf.model.builder.MRFBuilder;
import ivory.smrf.model.expander.MRFExpander;
import ivory.smrf.model.importance.ConceptImportanceModel;

import java.io.IOException;
import java.util.Map;
import java.util.Set;

import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;

import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import edu.umd.cloud9.collection.DocnoMapping;

public class BatchQueryRunner {
  private static final Logger LOG = Logger.getLogger(BatchQueryRunner.class);

  protected DocnoMapping docnoMapping = null;
  protected FileSystem fs = null;
  protected String indexPath = null;
  protected RetrievalEnvironment env = null;

  protected final Map<String, String> queries = Maps.newLinkedHashMap();
  protected final Map<String, Node> importanceModels = Maps.newLinkedHashMap();
  protected final Map<String, Node> models = Maps.newLinkedHashMap();
  protected final Map<String, Node> docscores = Maps.newLinkedHashMap();
  protected final Map<String, Node> expanders = Maps.newHashMap();
  protected final Map<String, Map<String, Double>> judgments = Maps.newHashMap();
  protected final Map<String, QueryRunner> queryRunners = Maps.newLinkedHashMap();
  protected final Set<String> stopwords = Sets.newHashSet();

  public BatchQueryRunner(String[] args, FileSystem fs) throws ConfigurationException {
    init(args, fs);
  }

  public BatchQueryRunner(String[] args, FileSystem fs, String indexPath)
      throws ConfigurationException {
    this.indexPath = indexPath;
    init(args, fs);
  }

  public void init(String[] args, FileSystem fs) throws ConfigurationException {
    this.fs = Preconditions.checkNotNull(fs);
    Preconditions.checkNotNull(args);

    parseParameters(args);
    loadRetrievalEnv();
    try {
      docnoMapping = RetrievalEnvironment.loadDocnoMapping(indexPath, fs);
    } catch (IOException e) {
      throw new ConfigurationException("Failed to load Docnomapping: " + e.getMessage());
    }

    // Load static concept importance models
    for (Map.Entry<String, Node> n : importanceModels.entrySet()) {
      ConceptImportanceModel m = ConceptImportanceModel.get(n.getValue());
      env.addImportanceModel(n.getKey(), m);
    }

    // Load static docscores (e.g., spam score, PageRank, etc.).
    for (Map.Entry<String, Node> n : docscores.entrySet()) {
      String type = XMLTools.getAttributeValue(n.getValue(), "type", "");
      String provider = XMLTools.getAttributeValue(n.getValue(), "provider", "");
      String path = n.getValue().getTextContent();

      if (type.equals("") || provider.equals("") || path.equals("")) {
        throw new ConfigurationException("Invalid docscore!");
      }

      LOG.info("Loading docscore: type=" + type + ", provider=" + provider + ", path=" + path);
      env.loadDocScore(type, provider, path);
    }
  }

  protected void loadRetrievalEnv() throws ConfigurationException {
    try {
      env = new RetrievalEnvironment(indexPath, fs);
      env.initialize(true);
    } catch (IOException e) {
      throw new ConfigurationException("Failed to instantiate RetrievalEnvironment: "
          + e.getMessage());
    }
  }

  public void runQueries() {
    for (String modelID : models.keySet()) {
      Node modelNode = models.get(modelID);
      Node expanderNode = expanders.get(modelID);

      // Initialize retrieval environment variables.
      QueryRunner runner = null;
      MRFBuilder builder = null;
      MRFExpander expander = null;
      try {
        // Get the MRF builder.
        builder = MRFBuilder.get(env, modelNode.cloneNode(true));

        // Get the MRF expander.
        expander = null;
        if (expanderNode != null) {
          expander = MRFExpander.getExpander(env, expanderNode.cloneNode(true));
        }
        if (stopwords != null && stopwords.size() != 0) {
          expander.setStopwordList(stopwords);
        }

        int numHits = XMLTools.getAttributeValue(modelNode, "hits", 1000);

        LOG.info("number of hits: " + numHits);

        // Multi-threaded query evaluation still a bit unstable; setting
        // thread=1 for now.
        runner = new ThreadedQueryRunner(builder, expander, 1, numHits);

        queryRunners.put(modelID, runner);
      } catch (Exception e) {
        e.printStackTrace();
      }

      for (String queryID : queries.keySet()) {
        String rawQueryText = queries.get(queryID);
        String[] queryTokens = env.tokenize(rawQueryText);

        LOG.info(String.format("query id: %s, query: \"%s\"", queryID, rawQueryText));

        // Execute the query.
        runner.runQuery(queryID, queryTokens);
      }

      // Where should we output these results?
      Node model = models.get(modelID);
      String fileName = XMLTools.getAttributeValue(model, "output", null);
      boolean compress = XMLTools.getAttributeValue(model, "compress", false);

      try {
        ResultWriter resultWriter = new ResultWriter(fileName, compress, fs);
        printResults(modelID, runner, resultWriter);
        resultWriter.flush();
      } catch (IOException e) {
        throw new RuntimeException("Error: Unable to write results!");
      }
    }
  }

  public Set<String> getModels() {
    return models.keySet();
  }

  public Node getModel(String modelName) {
    return models.get(modelName);
  }

  public RetrievalEnvironment getRetrievalEnvironment() {
    return env;
  }

  public Map<String, String> getQueries() {
    return queries;
  }

  public Map<String, Double> getJudgmentSet(String qid) {
    return judgments.get(qid);
  }

  public Map<String, Accumulator[]> getResults(String model) {
    return queryRunners.get(model).getResults();
  }

  public Map<String, Map<String, Accumulator[]>> getAllResults() {
    Map<String, Map<String, Accumulator[]>> results = Maps.newHashMap();
    for (String model : getModels()) {
      results.put(model, getResults(model));
    }
    return results;
  }

  public DocnoMapping getDocnoMapping() {
    return docnoMapping;
  }

  private void printResults(String modelID, QueryRunner runner, ResultWriter resultWriter)
      throws IOException {

    for (String queryID : queries.keySet()) {
      // Get the ranked list for this query.
      Accumulator[] list = runner.getResults(queryID);
      if (list == null) {
        LOG.info("null results for: " + queryID);
        continue;
      }

      if (docnoMapping == null) {
        // Print results with internal docnos if unable to translate to
        // external docids.
        for (int i = 0; i < list.length; i++) {
          resultWriter.println(queryID + " Q0 " + list[i].docno + " " + (i + 1) + " "
              + list[i].score + " " + modelID);
        }
      } else {
        // Translate internal docnos to external docids.
        for (int i = 0; i < list.length; i++) {
          resultWriter.println(queryID + " Q0 " + docnoMapping.getDocid(list[i].docno)
              + " " + (i + 1) + " " + list[i].score + " " + modelID);
        }
      }
    }
  }

  private void parseParameters(String[] args) throws ConfigurationException {
    for (int i = 0; i < args.length; i++) {
      String element = args[i];
      Document d = null;

      try {
        d = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(
              fs.open(new Path(element)));
      } catch (SAXException e) {
        throw new ConfigurationException(e.getMessage());
      } catch (IOException e) {
        throw new ConfigurationException(e.getMessage());
      } catch (ParserConfigurationException e) {
        throw new ConfigurationException(e.getMessage());
      }

      parseQueries(d);
      parseImportanceModels(d);
      parseModels(d);
      parseStopwords(d);
      parseIndexLocation(d);
      parseDocscores(d);
      parseJudgments(d);
    }

    // Make sure we have some queries to run.
    if (queries.isEmpty()) {
      throw new ConfigurationException("Must specify at least one query!");
    }

    // Make sure there are models that need evaluated.
    if (models.isEmpty()) {
      throw new ConfigurationException("Must specify at least one model!");
    }

    // Make sure we have an index to run against.
    if (indexPath == null) {
      throw new ConfigurationException("Must specify an index!");
    }
  }

  private void parseQueries(Document d) throws ConfigurationException {
    NodeList queryNodes = d.getElementsByTagName("query");

    for (int i = 0; i < queryNodes.getLength(); i++) {
      // Get query XML node.
      Node node = queryNodes.item(i);

      // Get query id.
      String qid = XMLTools.getAttributeValueOrThrowException(node, "id",
            "Must specify a query id attribute for every query!");

      // Get query text.
      String queryText = node.getTextContent();

      // Add query to internal map.
      if (queries.get(qid) != null) {
        throw new ConfigurationException(
            "Duplicate query ids not allowed! Already parsed query with id=" + qid);
      }
      queries.put(qid, queryText);
    }
  }

  private void parseImportanceModels(Document d) throws ConfigurationException {
    NodeList models = d.getElementsByTagName("importancemodel");

    for (int i = 0; i < models.getLength(); i++) {
      Node node = models.item(i);

      String id = XMLTools.getAttributeValueOrThrowException(node, "id",
          "Must specify an id for every importancemodel!");

      // Add model to internal map.
      if (importanceModels.get(id) != null) {
        throw new ConfigurationException(
            "Duplicate importancemodel ids not allowed! Already parsed model with id=" + id);
      }
      importanceModels.put(id, node);
    }
  }

  private void parseModels(Document d) throws ConfigurationException {
    NodeList modelList = d.getElementsByTagName("model");

    for (int i = 0; i < modelList.getLength(); i++) {
      // Get model XML node.
      Node node = modelList.item(i);

      // Get model id.
      String id = XMLTools.getAttributeValueOrThrowException(node, "id",
          "Must specify a model id for every model!");

      // Parse parent nodes.
      NodeList children = node.getChildNodes();
      for (int j = 0; j < children.getLength(); j++) {
        Node child = children.item(j);
        if ("expander".equals(child.getNodeName())) {
          if (expanders.containsKey(id)) {
            throw new ConfigurationException("Only one expander allowed per model!");
          }
          expanders.put(id, child);
        }
      }

      // Add model to internal map.
      if (models.get(id) != null) {
        throw new ConfigurationException(
            "Duplicate model ids not allowed! Already parsed model with id=" + id);
      }
      models.put(id, node);
    }
  }

  private void parseStopwords(Document d) {
    NodeList stopwordsList = d.getElementsByTagName("stopword");

    for (int i = 0; i < stopwordsList.getLength(); i++) {
      Node node = stopwordsList.item(i);
      String stopword = node.getTextContent();
      stopwords.add(stopword);
    }
  }

  private void parseIndexLocation(Document d) throws ConfigurationException {
    NodeList indexList = d.getElementsByTagName("index");

    if (indexList.getLength() > 0) {
      if (indexPath == null) {
        indexPath = indexList.item(0).getTextContent();
      }
    }
  }

  private void parseDocscores(Document d) throws ConfigurationException {
    NodeList docscoresList = d.getElementsByTagName("docscore");

    for (int i = 0; i < docscoresList.getLength(); i++) {
      Node node = docscoresList.item(i);

      String docscoreType = XMLTools.getAttributeValueOrThrowException(node, "type",
          "Must specify a type for every docscore!");

      // Add model to internal map.
      if (docscores.get(docscoreType) != null) {
        throw new ConfigurationException(
            "Duplicate docscore types not allowed! Already parsed model with id="
                + docscoreType);
      }
      docscores.put(docscoreType, node);
    }
  }

  private void parseJudgments(Document d) throws ConfigurationException {
    NodeList judgmentsList = d.getElementsByTagName("judgment");

    for (int i = 0; i < judgmentsList.getLength(); i++) {
      // Get XML node.
      Node node = judgmentsList.item(i);

      // Get query, document, and judgment.
      String qid = XMLTools.getAttributeValueOrThrowException(node, "qid",
          "Each judgment must specify a qid attribute!");
      String docname = XMLTools.getAttributeValueOrThrowException(node, "doc",
          "Each judgment must specify a doc attribute!");
      String grade = XMLTools.getAttributeValueOrThrowException(node, "grade",
          "Each judgment must specify a grade attribute!");

      Map<String, Double> queryJudgments = judgments.get(qid);
      if (queryJudgments == null) {
        queryJudgments = Maps.newHashMap();
        judgments.put(qid, queryJudgments);
      }

      queryJudgments.put(docname, Double.parseDouble(grade));
    }
  }
}
TOP

Related Classes of ivory.smrf.retrieval.BatchQueryRunner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.