package edu.gslis.ttg.main;
import java.io.FileReader;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.thrift.TException;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import cc.twittertools.search.api.TrecSearchThriftClient;
import cc.twittertools.thrift.gen.TResult;
import edu.gslis.eval.Qrels;
import edu.gslis.queries.GQueries;
import edu.gslis.queries.GQueriesJsonImpl;
import edu.gslis.queries.GQuery;
import edu.gslis.ttg.clusters.Cluster;
import edu.gslis.ttg.clusters.Clusters;
import edu.gslis.ttg.clusters.clusterers.SimpleJaccardClusterer;
import edu.gslis.ttg.searchers.SimpleSearcher;
import edu.gslis.utils.ParameterBroker;
public class RunTTGBaseline {
private static final String DEFAULT_RUNTAG = "ttgbaseline";
private static final String HOST_OPTION = "host";
private static final String TRAINING_PORT = "training_port";
private static final String TESTING_PORT = "testing_port";
private static final String QUERIES_OPTION = "queries";
private static final String NUM_RESULTS_OPTION = "num_results";
private static final String GROUP_OPTION = "group";
private static final String TOKEN_OPTION = "token";
private static final String RUNTAG_OPTION = "runtag";
private static final String JACCARD_STEP_SIZE = "jaccard_step";
private static final String TRAINING_QUERIES = "training_queries";
private static final String TRAINING_CLUSTERS = "training_clusters";
private static final String QRELS_OPTION = "qrels";
private static final String EVALUATION_OPTION = "evaluation_type";
public static void main(String[] args) throws NumberFormatException, TException, UnsupportedEncodingException {
ParameterBroker params = new ParameterBroker("./config/run_params.json");
PrintStream out = new PrintStream(System.out, true, "UTF-8");
PrintStream err = new PrintStream(System.err, true, "UTF-8");
GQueries trainingQueries = new GQueriesJsonImpl();
trainingQueries.setMetadataField("querytweettime");
trainingQueries.read(params.getParamValue(TRAINING_QUERIES));
GQueries queries = new GQueriesJsonImpl();
queries.setMetadataField("querytweettime");
queries.read(params.getParamValue(QUERIES_OPTION));
Qrels qrels = new Qrels(params.getParamValue(QRELS_OPTION), false, 1);
// max number of docs to send to output
int numResults = 1000;
try {
if (params.getParamValue(NUM_RESULTS_OPTION) != null) {
numResults = Integer.parseInt(params.getParamValue(NUM_RESULTS_OPTION));
}
} catch (NumberFormatException e) {
err.println("Invalid " + NUM_RESULTS_OPTION + ": " + params.getParamValue(NUM_RESULTS_OPTION));
System.exit(-1);
}
// authentication credentials
String group = params.getParamValue(GROUP_OPTION);
if(group==null) {
err.println("Invalid " + GROUP_OPTION + ": must set a valid group ID");
System.exit(-1);
}
String token = params.getParamValue(TOKEN_OPTION);
if(group==null) {
err.println("Invalid " + TOKEN_OPTION + ": must set a valid authentication token");
System.exit(-1);
}
// ports
int trainingPort = 9090;
try {
if (params.getParamValue(TRAINING_PORT) != null) {
trainingPort = Integer.parseInt(params.getParamValue(TRAINING_PORT));
}
} catch (NumberFormatException e) {
err.println("Invalid " + TRAINING_PORT + ": " + params.getParamValue(TRAINING_PORT));
System.exit(-1);
}
int testingPort = 9091;
try {
if (params.getParamValue(TESTING_PORT) != null) {
testingPort = Integer.parseInt(params.getParamValue(TESTING_PORT));
}
} catch (NumberFormatException e) {
err.println("Invalid " + TESTING_PORT + ": " + params.getParamValue(TESTING_PORT));
System.exit(-1);
}
// run tag
String runTag = params.getParamValue(RUNTAG_OPTION);
if(runTag==null) {
runTag = DEFAULT_RUNTAG;
}
// jaccard step size
double stepSize = 0.1;
try {
if (params.getParamValue(JACCARD_STEP_SIZE) != null) {
stepSize = Double.parseDouble(params.getParamValue(JACCARD_STEP_SIZE));
}
} catch (NumberFormatException e) {
err.println("Invalid " + JACCARD_STEP_SIZE + ": " + params.getParamValue(JACCARD_STEP_SIZE));
System.exit(-1);
}
// weighted or unweighted evaluation
String evalType = "weighted";
try {
if (params.getParamValue(EVALUATION_OPTION) != null) {
evalType = params.getParamValue(EVALUATION_OPTION);
}
} catch (Exception e) {
err.println("Invalid " + EVALUATION_OPTION + ": " + params.getParamValue(EVALUATION_OPTION));
System.exit(-1);
}
// 2 decimal places
DecimalFormat df = new DecimalFormat("#.##");
// read in training data
String trainingFile = params.getParamValue(TRAINING_CLUSTERS);
if (trainingFile==null) {
err.println("Invalid " + TRAINING_CLUSTERS + ": please provide valid file.");
System.exit(-1);
}
// parse training data into clusters
Map<String, Clusters> clusterMembership = new HashMap<String, Clusters>();
JSONParser parser = new JSONParser();
try {
JSONObject parseObj = (JSONObject) parser.parse(new FileReader(trainingFile));
JSONObject topicObj = (JSONObject) parseObj.get("topics");
Set<String> topics = topicObj.keySet();
Iterator<String> topicIt = topics.iterator();
while (topicIt.hasNext()) { // for each topic
String topic = topicIt.next();
clusterMembership.put(topic, new Clusters());
JSONArray clusters = (JSONArray) ((JSONObject) topicObj.get(topic)).get("clusters");
Iterator<JSONArray> clusterIt = clusters.iterator();
while (clusterIt.hasNext()) { // for each cluster in the topic
JSONArray cluster = (JSONArray) clusterIt.next();
Cluster c = new Cluster();
Iterator<String> clusterMemberIt = cluster.iterator();
while (clusterMemberIt.hasNext()) { // for each docId in the cluster
String member = clusterMemberIt.next();
long memberId = Long.parseLong(member);
c.add(memberId);
}
clusterMembership.get(topic).add(c);
}
}
} catch (Exception e) {
err.println("Error reading training data.");
e.printStackTrace();
System.exit(-1);
}
// instantiate search client
TrecSearchThriftClient client = new TrecSearchThriftClient(params.getParamValue(HOST_OPTION),
trainingPort, group, token);
SimpleSearcher searcher = new SimpleSearcher(client, numResults);
err.println("=== Train Queries ===");
List<Double> thresholds = new ArrayList<Double>();
double averageThreshold = 0;
Iterator<GQuery> queryIterator = trainingQueries.iterator();
while(queryIterator.hasNext()) {
GQuery query = queryIterator.next();
Map<Long, TResult> seenResults = searcher.search(query);
SimpleJaccardClusterer clusterer = new SimpleJaccardClusterer(new ArrayList<TResult>(seenResults.values()));
// sweep through jaccard steps, calculating F1
double maxF1 = 0;
double maxF1Threshold = 1;
for (double j = 1.0; j >= 0.0; j -= stepSize) { // for each jaccard threshold step
Clusters clusters = clusterer.cluster(j);
// all clusters are created now, get a finalized set of results
Set<Long> allResults = new HashSet<Long>(seenResults.keySet());
allResults.removeAll(clusters.getAllClusteredResults()); // allResults includes unclustered plus one representative from each cluster
for (Cluster c : clusters) {
allResults.add(c.getFirstMember());
}
// calculate f1 on the finalized set
Clusters seenClusters = new Clusters();
Clusters trueClusters = clusterMembership.get(query.getTitle());
Iterator<Long> resultIt = allResults.iterator();
while (resultIt.hasNext()) {
long result = resultIt.next();
Cluster trueCluster = trueClusters.findCluster(result);
if (trueCluster != null) { // if it is relevant, it will have a true cluster; if this is null, it's non-relevant
seenClusters.add(trueCluster);
}
}
int numRetrievedClusters = seenClusters.size();
int numResultsReturned = allResults.size();
int numTrueClusters = trueClusters.size();
double precision = 0;
double recall = 0;
double f1 = 0;
if (evalType.equals("unweighted")) {
precision = numRetrievedClusters / (double) numResultsReturned;
recall = numRetrievedClusters / (double) numTrueClusters;
f1 = 2 * precision * recall / (precision + recall);
} else {
// for weighted measurements, we need the weight of each cluster
int retrievedWeight = 0;
for (Cluster cluster : seenClusters) {
int w = cluster.getWeight(query, qrels);
retrievedWeight += w;
}
int resultsWeight = 0;
for (long result : allResults) {
int w = 0;
if (seenClusters.findCluster(result) == null)
resultsWeight += w;
}
int trueWeight = 0;
for (Cluster cluster : trueClusters) {
int w = cluster.getWeight(query, qrels);
trueWeight += w;
}
precision = retrievedWeight / (double) resultsWeight; // <--- ??????
recall = retrievedWeight / (double) trueWeight;
f1 = 2 * precision * recall / (precision + recall);
}
if (f1 > maxF1) {
maxF1 = f1;
maxF1Threshold = j;
}
}
thresholds.add(maxF1Threshold);
err.println("F1: "+df.format(maxF1)+"; Jaccard: "+df.format(maxF1Threshold));
}
// get the average threshold
for (double threshold : thresholds) {
averageThreshold += threshold;
}
averageThreshold /= thresholds.size();
err.println("Average Jaccard: "+averageThreshold);
err.println("=== Test Queries ===");
// now cluster the test queries and output
queryIterator = queries.iterator();
while(queryIterator.hasNext()) {
GQuery query = queryIterator.next();
err.println(query.getTitle());
client = new TrecSearchThriftClient(params.getParamValue(HOST_OPTION), testingPort, group, token);
searcher = new SimpleSearcher(client, numResults);
Map<Long, TResult> seenResults = searcher.search(query);
SimpleJaccardClusterer clusterer = new SimpleJaccardClusterer(new ArrayList<TResult>(seenResults.values()));
Clusters clusters = clusterer.cluster(averageThreshold);
// all clusters are created now, get a finalized set of results
Set<Long> allResults = new HashSet<Long>(seenResults.keySet());
allResults.removeAll(clusters.getAllClusteredResults()); // allResults includes unclustered plus one representative from each cluster
for (Cluster c : clusters) {
allResults.add(c.getFirstMember());
}
int i = 0;
for (long result : allResults) {
TResult hit = seenResults.get(result);
out.println(String.format("%s Q0 %s %d %f %s", query.getTitle(), hit.getId(), i, hit.getRsv(), runTag));
i++;
}
}
out.close();
err.close();
}
}