package ch.akuhn.hapax.index;
import static ch.akuhn.foreach.For.matrix;
import static ch.akuhn.foreach.For.range;
import static ch.akuhn.foreach.For.withIndex;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Iterator;
import ch.akuhn.foreach.Each;
import ch.akuhn.foreach.EachXY;
import ch.akuhn.hapax.corpus.Terms;
import ch.akuhn.hapax.linalg.Matrix;
import ch.akuhn.hapax.linalg.SVD;
import ch.akuhn.hapax.linalg.SymetricMatrix;
import ch.akuhn.util.Files;
import ch.akuhn.util.PrintOn;
import ch.akuhn.util.Providable;
import ch.akuhn.util.Bag.Count;
public class LatentSemanticIndex implements Serializable {
private static final long serialVersionUID = 1337L;
private static final int VERSION_1 = 0x20080830;
private AssociativeList<String> documents;
private AssociativeList<String> terms;
private SVD svd;
private int[] documentLength;
private double[] globalWeighting;
@SuppressWarnings("unchecked")
private void readObject(ObjectInputStream in) throws Exception {
int version = in.readInt();
if (version != VERSION_1) throw new Error();
terms = new AssociativeList<String>((Iterable<String>) in.readObject());
documents = new AssociativeList<String>((Iterable<String>) in.readObject());
svd = (SVD) in.readObject();
if (!this.invariant()) throw new Error();
}
private boolean invariant() {
return true; // TODO Auto-generated method stub
}
private void writeObject(ObjectOutputStream out) throws Exception {
out.writeInt(VERSION_1);
out.writeObject(terms.asList());
out.writeObject(documents.asList());
out.writeObject(svd);
}
public LatentSemanticIndex(
AssociativeList<String> terms,
AssociativeList<String> documents,
SVD svd) {
this.documents = documents;
this.terms = terms;
this.svd = svd;
if (svd.getRank() == 0) return;
if (svd.rowCount() != terms.size()) this.svd = svd.transposed();
assert svd.rowCount() == terms.size();
assert svd.columnCount() == documents.size();
this.assertInvariant();
}
public double[] createPseudoDocument(String string) {
// apply: CamelCaseScanner, PorterStemmer, toLowerCase, and weighting
Terms query = new Terms(string).toLowerCase().stem();
return createPseudoDocument(query);
}
public double[] createPseudoDocument(Terms query) {
double[] weightings = new double[termCount()];
// iterate over query, assume: quert.size() <<< terms.size()
for (Count<String> each: query.counts()) {
int index = terms.get(each.element);
if (index < 0) continue;
double weight = (globalWeighting == null ? 1 : globalWeighting[index]);
weightings[index] = each.count * weight;
}
return svd.makePseudoV(weightings);
}
public Ranking<String> rankDocumentsByDocument(String d) {
Ranking<String> ranking = new Ranking<String>();
int n = documents.get(d);
for (Each<String> each: withIndex(documents)) {
ranking.add(each.value, svd.similarityVV(n, each.index));
}
return ranking.sort();
}
public Ranking<String> rankDocumentsByQuery(String query) {
Ranking<String> ranking = new Ranking<String>();
double[] pseudo = createPseudoDocument(query);
for (Each<String> each: withIndex(documents)) {
ranking.add(each.value, svd.similarityV(each.index, pseudo));
}
return ranking.sort();
}
public Ranking<String> rankDocumentsByQuery(Terms query) {
Ranking<String> ranking = new Ranking<String>();
double[] pseudo = createPseudoDocument(query);
for (Each<String> each: withIndex(documents)) {
ranking.add(each.value, svd.similarityV(each.index, pseudo));
}
return ranking.sort();
}
public Ranking<String> rankDocumentsByTerm(String term) {
Ranking<String> ranking = new Ranking<String>();
int n = terms.get(term);
assert n >= 0;
for (Each<String> each: withIndex(documents)) {
ranking.add(each.value, svd.similarityUV(n, each.index));
}
return ranking.sort();
}
public Ranking<CharSequence> rankTermsByDocument(String d) {
Ranking<CharSequence> ranking = new Ranking<CharSequence>();
int n = documents.get(d);
for (Each<String> each: withIndex(terms)) {
ranking.add(each.value, svd.similarityUV(each.index, n));
}
return ranking.sort();
}
public Ranking<CharSequence> rankTermsByTerm(String term) {
Ranking<CharSequence> ranking = new Ranking<CharSequence>();
int n = terms.get(term);
for (Each<String> each: withIndex(terms)) {
ranking.add(each.value, svd.similarityUU(n, each.index));
}
return ranking.sort();
}
public Matrix documentCorrelation() {
Matrix correlation = new SymetricMatrix(documents.size());
for (int row: range(documents.size())) {
for (int column: range(documents.size())) {
correlation.put(row, column, svd.similarityVV(row, column));
}
}
//Appendable file = Files.openWrite("document-correlation.json");
//correlation.storeOn(file);
//Files.close(file);
return correlation;
}
public Matrix euclidianDistance() {
Matrix dist = new SymetricMatrix(documents.size());
for (int row: range(documents.size())) {
for (int column: range(documents.size())) {
dist.put(row, column, svd.euclidianVV(row, column));
}
}
return dist;
}
public Iterable<Double> documentCorrelations() {
return new Providable<Double>() {
private Iterator<EachXY> iter;
@Override
public void initialize() {
iter = matrix(documents.size(), documents.size()).iterator();
}
@Override
public Double provide() {
if (!iter.hasNext()) return done();
EachXY each = iter.next();
return svd.similarityVV(each.x, each.y);
}
};
}
public int documentCount() {
return documents.size();
}
public int termCount() {
return terms.size();
}
public void updateDocument(String doc, Terms contents) {
double[] newDocument = createPseudoDocument(contents);
int index = documents.get(doc);
if (index >= 0) {
svd = svd.withReplaceV(index, newDocument);
}
else {
documents.add(doc);
svd = svd.withAppendV(newDocument);
}
this.assertInvariant();
}
public void removeDocument(String doc) {
int index = documents.remove(doc);
if (index < 0) return;
svd = svd.withoutV(index);
// todo adapate termCounts
this.assertInvariant();
}
public Iterable<String> documents() {
return documents;
}
public LatentSemanticIndex initializeDocumentLength(int[] lengthArray) {
this.documentLength = lengthArray;
this.assertInvariant();
return this;
}
public int getDocumentLength(String doc) {
int index = documents.get(doc);
if (index < 0) return -1;
return documentLength[index];
}
private void assertInvariant() {
if (documentLength != null && documentLength.length != documents.size()) throw new AssertionError();
if (svd.columnCount() != documents.size()) throw new AssertionError();
if (svd.rowCount() != terms.size()) throw new AssertionError();
if (globalWeighting != null && globalWeighting.length != terms.size()) throw new AssertionError();
}
public LatentSemanticIndex initializeGlobalWeightings(double[] globalWeightings) {
this.globalWeighting = globalWeightings;
return this;
}
}