package edu.wiki.modify;
import gnu.trove.TIntDoubleHashMap;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.index.TermFreqVector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
/**
* Reads term and document frequencies from the Lucene index and
* writes cosine-normalized, inlink-boosted TF-IDF vectors to the database.
*
* Normalization and vector pruning are performed as in Gabrilovich and Markovitch (2009).
*
* Usage: MemIndexModifier <Lucene index location>
*
* @author Cagatay Calli <ccalli@gmail.com>
*
*/
public class MemIndexModifier {
static Connection connection = null;
static Statement stmtLink;
static PreparedStatement pstmtVector;
// static String strLoadData = "LOAD DATA LOCAL INFILE 'vector.txt' INTO TABLE idx FIELDS ENCLOSED BY \"'\"";
static String strVectorQuery = "INSERT INTO idx (term,vector) VALUES (?,?)";
static String strTermLoadData = "LOAD DATA LOCAL INFILE 'term.txt' INTO TABLE terms FIELDS ENCLOSED BY \"'\"";
static String strAllInlinks = "SELECT target_id,inlink FROM inlinks";
static String strLimitQuery = "SELECT COUNT(id) FROM article;";
private static IndexReader reader = null;
static int limitID; // article count (COUNT(id) from the 'article' table), used to size inlinkMap
private static TIntDoubleHashMap inlinkMap;
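// Pruning parameters: a term's concept vector is truncated once the score drop across a
// sliding window of WINDOW_SIZE documents falls below WINDOW_THRES of its highest score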
static int WINDOW_SIZE = 100;
static float WINDOW_THRES = 0.005f;
static DecimalFormat df = new DecimalFormat("#.########");
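/**
* A (Wikipedia article id, score) entry of a term's concept vector; orders by score, descending.
*/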
public static class DocScore implements Comparable<DocScore> {
int doc;
float score;
public DocScore(int doc, float score) {
this.doc = doc;
this.score = score;
}
@Override
public int compareTo(DocScore o) {
// order by score, descending
return Float.compare(o.score, this.score);
}
}
/**
* global, term-doc matrix
*/
static HashMap<String, ArrayList<DocScore>> matrix;
public static void initDB() throws ClassNotFoundException, SQLException, IOException {
// Load the JDBC driver
String driverName = "com.mysql.jdbc.Driver"; // MySQL Connector
Class.forName(driverName);
// read DB config
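// db.conf is expected to contain four lines: server name, database name, user name, password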
InputStream is = MemIndexModifier.class.getResourceAsStream("/config/db.conf");
BufferedReader br = new BufferedReader(new InputStreamReader(is));
String serverName = br.readLine();
String mydatabase = br.readLine();
String username = br.readLine();
String password = br.readLine();
br.close();
// Create a connection to the database
String url = "jdbc:mysql://" + serverName + "/" + mydatabase + "?useUnicode=yes&characterEncoding=UTF-8"; // a JDBC url
connection = DriverManager.getConnection(url, username, password);
stmtLink = connection.createStatement();
stmtLink.setFetchSize(200);
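// (Re)create the output tables: 'idx' stores one serialized concept vector blob per term,
// 'terms' stores each term's IDF value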
stmtLink.execute("DROP TABLE IF EXISTS idx");
stmtLink.execute("CREATE TABLE idx (" +
"term VARBINARY(255)," +
"vector MEDIUMBLOB " +
") DEFAULT CHARSET=binary");
stmtLink.execute("DROP TABLE IF EXISTS terms");
stmtLink.execute("CREATE TABLE terms (" +
"term VARBINARY(255)," +
"idf FLOAT " +
") DEFAULT CHARSET=binary");
ResultSet res = stmtLink.executeQuery(strLimitQuery);
res.next();
limitID = res.getInt(1);
// read inlink counts
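// each article's boost is a doubly dampened inlink count: log(1 + log(1 + #inlinks))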
inlinkMap = new TIntDoubleHashMap(limitID);
int targetID, numInlinks;
res = stmtLink.executeQuery(strAllInlinks);
while(res.next()){
targetID = res.getInt(1);
numInlinks = res.getInt(2);
inlinkMap.put(targetID, Math.log(1+Math.log(1+numInlinks)));
}
pstmtVector = connection.prepareStatement(strVectorQuery);
}
/**
* @param args args[0] is the Lucene index location
* @throws IOException
* @throws SQLException
* @throws ClassNotFoundException
*/
public static void main(String[] args) throws IOException, ClassNotFoundException, SQLException {
if(args.length < 1){
System.out.println("Usage: MemIndexModifier <Lucene index location>");
System.exit(-1);
}
try {
Directory fsdir = FSDirectory.open(new File(args[0]));
reader = IndexReader.open(fsdir, true); // open the index read-only
} catch (Exception ex) {
System.out.println("Cannot open index... " + ex.getMessage());
System.exit(-1);
}
initDB();
long sTime, eTime;
sTime = System.currentTimeMillis();
int maxid = reader.maxDoc();
TermFreqVector tv;
String[] terms;
String term = "";
Term t;
int tcount;
int tfreq = 0;
float idf;
float tf;
float tfidf;
double inlinkBoost;
double sum;
int wikiID;
int hashInt;
int numDocs = reader.numDocs();
TermEnum tnum = reader.terms();
HashMap<String, Float> idfMap = new HashMap<String, Float>(500000);
HashMap<String, Float> tfidfMap = new HashMap<String, Float>(5000);
HashMap<String, Integer> termHash = new HashMap<String, Integer>(500000);
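// First pass over the index: for every term occurring in at least 3 documents,
// compute IDF = ln(numDocs / docFreq) and assign it a sequential id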
hashInt = 0;
tcount = 0;
while(tnum.next()){
t = tnum.term();
term = t.text();
tfreq = tnum.docFreq(); // get DF for the term
// skip rare terms
if(tfreq < 3){
continue;
}
// idf = (float)(Math.log(numDocs/(double)(tfreq+1)) + 1.0);
idf = (float)(Math.log(numDocs/(double)(tfreq)));
// idf = (float)(Math.log(numDocs/(double)(tfreq)) / Math.log(2));
idfMap.put(term, idf);
termHash.put(term, hashInt++);
tcount++;
}
matrix = new HashMap<String, ArrayList<DocScore>>(tcount);
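// Second pass: for every live document, read its term vector, compute TF-IDF
// (TF = 1 + ln(freq)), cosine-normalize the document vector, apply the inlink boost,
// and append a (doc, score) entry to each term's list in 'matrix'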
for(int i=0;i<maxid;i++){
if(!reader.isDeleted(i)){
//System.out.println(i);
wikiID = Integer.valueOf(reader.document(i).getField("id").stringValue());
inlinkBoost = inlinkMap.get(wikiID);
tv = reader.getTermFreqVector(i, "contents");
try {
terms = tv.getTerms();
int[] fq = tv.getTermFrequencies();
sum = 0.0;
tfidfMap.clear();
// for all terms of a document
for(int k=0;k<terms.length;k++){
term = terms[k];
if(!idfMap.containsKey(term))
continue;
tf = (float) (1.0 + Math.log(fq[k]));
// tf = (float) (1.0 + Math.log(fq[k]) / Math.log(2));
idf = idfMap.get(term);
tfidf = (float) (tf * idf);
tfidfMap.put(term, tfidf);
sum += tfidf * tfidf;
}
sum = Math.sqrt(sum); // L2 norm of the document's TF-IDF vector (cosine normalization denominator)
// for all terms of a document
for(int k=0;k<terms.length;k++){
term = terms[k];
if(!idfMap.containsKey(term))
continue;
tfidf = (float) (tfidfMap.get(term) / sum * inlinkBoost); // cosine-normalize, then apply the inlink boost
// System.out.println(i + ": " + term + " " + fq[k] + " " + tfidf);
// ++++ record to DB (term,doc,tfidf) +++++
if(matrix.containsKey(term)){
matrix.get(term).add(new DocScore(wikiID, tfidf));
}
else {
ArrayList<DocScore> dsl = new ArrayList<DocScore>();
dsl.add(new DocScore(wikiID, tfidf));
matrix.put(term, dsl);
}
}
}
catch(Exception e){
e.printStackTrace();
System.out.println("ERR: " + wikiID + " " + tv);
continue;
}
}
}
int doc;
float score;
// for pruning
int mark, windowMark;
float first = 0, last = 0, highest = 0;
float [] window = new float[WINDOW_SIZE];
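// For each term: sort its documents by score (descending), prune the tail with the
// sliding window, and store the surviving entries as a single blob in the 'idx' table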
for(String k : matrix.keySet()){
term = k;
List<DocScore> ds = matrix.get(k);
Collections.sort(ds);
// prune and write the vector
{
// prune the vector
mark = 0;
windowMark = 0;
highest = first = last = 0;
ByteArrayOutputStream baos = new ByteArrayOutputStream(50000);
DataOutputStream tdos = new DataOutputStream(baos);
for(DocScore d : ds){
doc = d.doc;
score = d.score;
// sliding window over the last WINDOW_SIZE scores (circular buffer)
window[windowMark] = score;
if(mark == 0){
// top-scoring document: initializes the window head and the overall maximum
highest = score;
first = score;
}
if(mark < WINDOW_SIZE){
// always keep the first WINDOW_SIZE entries
tdos.writeInt(doc);
tdos.writeFloat(score);
}
else if( highest*WINDOW_THRES < (first - last) ){
// scores are still dropping fast enough across the window
// (drop exceeds WINDOW_THRES of the highest score): keep the entry
tdos.writeInt(doc);
tdos.writeFloat(score);
// advance 'first' to the oldest score remaining in the circular buffer
if(windowMark < WINDOW_SIZE-1){
first = window[windowMark+1];
}
else {
first = window[0];
}
}
else {
// scores have flattened out: truncate the vector here
break;
}
last = score;
mark++;
windowMark++;
windowMark = windowMark % WINDOW_SIZE;
}
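// Blob layout: [int docCount][docCount x (int wikiID, float score)], big-endian as
// written by DataOutputStream. A consumer could decode it roughly like this
// (sketch, assuming a ResultSet 'rs' positioned on a row of the 'idx' table):
//   DataInputStream dis = new DataInputStream(rs.getBlob("vector").getBinaryStream());
//   int n = dis.readInt();
//   for (int j = 0; j < n; j++) {
//     int docId = dis.readInt();
//     float docScore = dis.readFloat();
//   }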
ByteArrayOutputStream dbvector = new ByteArrayOutputStream();
DataOutputStream dbdis = new DataOutputStream(dbvector);
dbdis.writeInt(mark);
dbdis.flush();
dbvector.write(baos.toByteArray());
dbvector.flush();
dbdis.close();
// write to DB
pstmtVector.setString(1, term);
pstmtVector.setBlob(2, new ByteArrayInputStream(dbvector.toByteArray()));
pstmtVector.execute();
tdos.close();
baos.close();
}
}
// record term IDFs
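// one line per term in the form 'escaped term'<TAB>idf, bulk-loaded into 'terms' with
// LOAD DATA and then indexed on the first 32 bytes of the term column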
FileOutputStream tos = new FileOutputStream("term.txt");
OutputStreamWriter tsw = new OutputStreamWriter(tos,"UTF-8");
for(String tk : idfMap.keySet()){
tsw.write("'" + tk.replace("\\","\\\\").replace("'","\\'") + "'\t"+idfMap.get(tk)+"\n");
}
tsw.close();
stmtLink.execute(strTermLoadData);
stmtLink.execute("CREATE INDEX idx_term ON terms (term(32))");
eTime = System.currentTimeMillis();
System.out.println("Total TIME (sec): "+ (eTime-sTime)/1000.0);
reader.close();
connection.close();
}
}