// Package ivory.cascade.retrieval
//
// Source code of ivory.cascade.retrieval.CascadeThreadedQueryRunner

/*
* Ivory: A Hadoop toolkit for web-scale information retrieval
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package ivory.cascade.retrieval;

import ivory.core.RetrievalEnvironment;
import ivory.core.exception.ConfigurationException;
import ivory.smrf.model.MarkovRandomField;
import ivory.smrf.model.builder.MRFBuilder;
import ivory.smrf.model.expander.MRFExpander;
import ivory.smrf.retrieval.Accumulator;
import ivory.smrf.retrieval.MRFDocumentRanker;

import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.log4j.Logger;


import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;

/**
* @author Lidan Wang
*/
public class CascadeThreadedQueryRunner implements CascadeQueryRunner {
  private static final Logger sLogger = Logger.getLogger(CascadeThreadedQueryRunner.class);

  private MRFBuilder mBuilder;
  private MRFExpander mExpander;
  private ExecutorService mThreadPool;
  private Map<String, Future<Accumulator[]>> mQueryResults;
  private int mNumHits;
  private Map<Integer, Float[][]> savedResults_prevStage = Maps.newHashMap(); //for all queries
  private int mK; //K value used in cascade model

  //assume no more than 1000 queries
  private float [] cascadeCostAllQueries = new float[1000];
  private float [] cascadeCostAllQueries_lastStage = new float[1000];

  public CascadeThreadedQueryRunner(MRFBuilder builder, MRFExpander expander, int numThreads,
      int numHits, Map<Integer, Float[][]> savedResults, int K) {
    Preconditions.checkNotNull(builder);

    assert (numThreads > 0);
    assert (numHits > 0);

    mBuilder = builder;
    mExpander = expander;
    mThreadPool = Executors.newFixedThreadPool(numThreads);
    mQueryResults = Maps.newLinkedHashMap();
    mNumHits = numHits;
    savedResults_prevStage = savedResults;
    mK = K;
  }

  /**
   * Runs a query asynchronously. Results can be fetched using
   * {@link getResults}.
   */
  public void runQuery(String qid, String[] query) {
    Preconditions.checkNotNull(qid);
    Preconditions.checkNotNull(query);

    Future<Accumulator[]> future = mThreadPool.submit(new ThreadTask(query, mBuilder,
        mExpander, qid, mNumHits));
    mQueryResults.put(qid, future);
  }

  /**
   * Runs a query synchronously, waiting until completion.
   */
  public Accumulator[] runQuery(String[] query) {
    Preconditions.checkNotNull(query);

    Future<Accumulator[]> future = mThreadPool.submit(new ThreadTask(query, mBuilder,
        mExpander, "query", mNumHits));
    Accumulator[] results = null;
    try {
      results = future.get();
    } catch (Exception e) {
      e.printStackTrace();
    }
    return results;
  }

  /**
   * Fetches the results of a query. If necessary, waits until completion of
   * the query.
   *
   * @param qid
   *            query id
   */
  public Accumulator[] getResults(String qid) {
    try {
      return mQueryResults.get(qid).get();
    } catch (InterruptedException e) {
      e.printStackTrace();
      return null;
    } catch (ExecutionException e) {
      e.printStackTrace();
      return null;
    }
  }

  /**
   * Clears all stored results.
   */
  public void clearResults() {
    mQueryResults.clear();
  }

  /**
   * Returns results of all queries executed.
   */
  public Map<String, Accumulator[]> getResults() {
    Map<String, Accumulator[]> results = Maps.newLinkedHashMap();
    for (Map.Entry<String, Future<Accumulator[]>> e : mQueryResults.entrySet()) {
      try {
        Accumulator[] a = e.getValue().get();
       
        if ( a != null) {
          results.put(e.getKey(), e.getValue().get());
        }
       
      } catch (Exception ex) {
        ex.printStackTrace();
      }
    }
    return results;
  }


  //return cascade cost for all queries
  public float[] getCascadeCostAllQueries(){
    return cascadeCostAllQueries;
  }

  public float[] getCascadeCostAllQueries_lastStage(){
    return cascadeCostAllQueries_lastStage;
  }

  // Thread for running a query.  No need to expose implementation.
  private class ThreadTask implements Callable<Accumulator[]> {
    private String[] mQuery;
    private MRFBuilder mBuilder;
    private MRFExpander mExpander;
    private String mQid;
    private int mNumHits;

    public ThreadTask(String[] query, MRFBuilder builder, MRFExpander expander, String qid, int numHits) {
      mQuery = query;
      mBuilder = builder;
      mExpander = expander;
      mQid = qid;
      mNumHits = numHits;
    }

    public Accumulator[] call() {
      try {
        long startTime;
        long endTime;

        startTime = System.currentTimeMillis();

        // Build the MRF for this query.
        Object r = savedResults_prevStage.get(mQid);
        float [][] savedResults = null; //store docno and score
        if (r!=null){
          savedResults = (float[][]) r;
        }
        MarkovRandomField mrf = mBuilder.buildMRF(mQuery);

        // Run initial query, if necessary.
        Accumulator[] results = null;

        float cascadeCost = -1;
        float cascadeCost_lastStage = -1;

        if (mrf.getCliques().size()==0){
        }

        else{
          if (RetrievalEnvironment.mIsNewModel){

            CascadeEval ranker = new CascadeEval (mrf, mNumHits, mQid, savedResults, mK);

            // Rank the documents using the cascade model.
                                          results = ranker.rank();
            cascadeCost = ranker.getCascadeCost();
          }

          else
            // Retrieve documents using this MRF.
            MRFDocumentRanker ranker = new MRFDocumentRanker(mrf, mNumHits);

            if (mExpander != null) {
              results = ranker.rank();
            }

            // Perform pseudo-relevance feedback, if requested.
            if (mExpander != null) {
              // Get expanded MRF.
              MarkovRandomField expandedMRF = mExpander.getExpandedMRF(mrf, results);

              // Re-rank documents according to expanded MRF.
              ranker = new MRFDocumentRanker(expandedMRF, mNumHits);
            }
 
            // Rank the documents.
            results = ranker.rank();

            //cascadeCost = ranker.getCost();
            cascadeCost = -1; //todo: later
       
          }
        }
        endTime = System.currentTimeMillis();
        sLogger.info("Processed query " + mQid + " in " + (endTime - startTime) + " ms.");

        //This stores the cascade cost for this query using the model represented by modelID
        if (cascadeCost != -1){
          //String key = BatchQueryRunner.model_ID + " "+mQid;
          //BatchQueryRunner.cascadeCosts.put(key, cascadeCost+"");

          cascadeCostAllQueries[Integer.parseInt(mQid)] = cascadeCost;
        }

        if (cascadeCost_lastStage!=-1){
          cascadeCostAllQueries_lastStage[Integer.parseInt(mQid)] = cascadeCost_lastStage;
        }

        return results;
      } catch (ConfigurationException e) {
        e.printStackTrace();
        sLogger.error(e.getMessage());
       
        return null;
      }
    }
  }
}
// TOP
//
// Related Classes of ivory.cascade.retrieval.CascadeThreadedQueryRunner
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.