package ch.akuhn.hapax.index;
import static ch.akuhn.foreach.For.withIndex;
import java.util.NoSuchElementException;
import ch.akuhn.foreach.Each;
import ch.akuhn.hapax.corpus.Corpus;
import ch.akuhn.hapax.corpus.PorterStemmer;
import ch.akuhn.hapax.corpus.Stemmer;
import ch.akuhn.hapax.corpus.Stopwords;
import ch.akuhn.hapax.corpus.Terms;
import ch.akuhn.hapax.linalg.SVD;
import ch.akuhn.hapax.linalg.SparseMatrix;
import ch.akuhn.hapax.linalg.Vector;
import ch.akuhn.hapax.linalg.Vector.Entry;
import ch.akuhn.util.IntArray;
import ch.akuhn.util.Pair;
import ch.akuhn.util.Bag.Count;
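/**
 * A term-document matrix: rows are terms, columns are documents, and each cell
 * holds the number of occurrences of a term in a document (a weighted value
 * after {@link #weight}). Transformations such as {@link #toLowerCase},
 * {@link #stem}, {@link #rejectStopwords} and {@link #rejectLegomena} return
 * new matrices rather than mutating the receiver.
 *
 * <p>A typical usage sketch; the document names and the {@link Terms} bags
 * ({@code termsOfFoo}, {@code termsOfBar}) are hypothetical and assumed to be
 * built elsewhere:
 *
 * <pre>
 * TermDocumentMatrix tdm = new TermDocumentMatrix();
 * tdm.putDocument("Foo.java", termsOfFoo);
 * tdm.putDocument("Bar.java", termsOfBar);
 * LatentSemanticIndex index = tdm.rejectAndWeight().createIndex();
 * </pre>
 */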
public class TermDocumentMatrix extends Corpus {
private static final int DEFAULT_DIMENSIONS = 25;
private AssociativeList<String> documents; // document names, one per matrix column
private double[] globalWeightings;
private IntArray lengthArray;
private SparseMatrix matrix;
private AssociativeList<String> terms; // terms, one per matrix row
public TermDocumentMatrix() {
this.matrix = new SparseMatrix(0, 0);
this.terms = new AssociativeList<String>();
this.documents = new AssociativeList<String>();
this.lengthArray = new IntArray();
}
private TermDocumentMatrix(AssociativeList<String> terms, AssociativeList<String> documents, IntArray lengthArray) {
this.matrix = new SparseMatrix(terms.size(), documents.size());
this.terms = terms.clone();
this.documents = documents.clone();
this.lengthArray = lengthArray.clone();
}
private void addToRow(String term, Vector values) {
int row = indexTerm(term);
matrix.addToRow(row, values);
}
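/** Adds a document to the matrix: records its length and increments the cell (term, document) by each term's count. */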
@Override
public void putDocument(String doc, Terms bag) {
int column = this.indexDocument(doc);
lengthArray.add(column, bag.size());
for (Count<String> term: bag.counts()) {
int row = this.indexTerm(term.element);
matrix.add(row, column, term.count);
}
}
@Override
public boolean containsDocument(String doc) {
return documents.contains(doc);
}
public LatentSemanticIndex createIndex() {
return this.createIndex(DEFAULT_DIMENSIONS);
}
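/** Creates a latent semantic index from a rank-{@code dimensions} SVD of this matrix. */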
public LatentSemanticIndex createIndex(int dimensions) {
return new LatentSemanticIndex(terms, documents, new SVD(matrix, dimensions))
.initializeGlobalWeightings(globalWeightings)
.initializeDocumentLength(lengthArray.asIntArray());
}
public double density() {
return matrix.density();
}
@Override
public Iterable<String> documents() {
return documents;
}
@Override
public int documentCount() {
return documents.size();
}
private int indexTerm(String term) {
int index = terms.add(term);
if (index == matrix.rowCount()) matrix.addRow();
return index;
}
private int indexDocument(String doc) {
int column = documents.add(doc);
if (column == matrix.columnCount()) matrix.addColumn();
return column;
}
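/** Applies the default preprocessing pipeline: lower-casing, removal of hapaxes and stopwords, stemming, and term-frequency/IDF weighting. */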
public TermDocumentMatrix rejectAndWeight() {
return toLowerCase()
.rejectHapaxes()
.rejectStopwords()
.stem()
.weight(LocalWeighting.TERM, GlobalWeighting.IDF);
}
public TermDocumentMatrix rejectHapaxes() {
return rejectLegomena(1);
}
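/** Returns a copy without the terms whose row has at most {@code threshold} non-zero entries, that is, terms occurring in no more than {@code threshold} documents. */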
public TermDocumentMatrix rejectLegomena(int threshold) {
TermDocumentMatrix tdm = new TermDocumentMatrix(new AssociativeList<String>(), documents, lengthArray);
for (Pair<String,Vector> each: termRowPairs()) {
if (each.snd.used() <= threshold) continue;
tdm.addToRow(each.fst, each.snd);
}
return tdm;
}
public TermDocumentMatrix rejectStopwords() {
return rejectStopwords(Stopwords.BASIC_ENGLISH);
}
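/** Returns a copy without the terms contained in the given stopword list. */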
public TermDocumentMatrix rejectStopwords(Stopwords stopwords) {
TermDocumentMatrix tdm = new TermDocumentMatrix(new AssociativeList<String>(), documents, lengthArray);
for (Pair<String,Vector> each: termRowPairs()) {
if (stopwords.contains(each.fst)) continue;
tdm.addToRow(each.fst, each.snd);
}
return tdm;
}
public TermDocumentMatrix stem() {
return stem(new PorterStemmer());
}
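/** Returns a copy with each term replaced by its stem; rows of terms that share a stem are summed. */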
public TermDocumentMatrix stem(Stemmer stemmer) {
TermDocumentMatrix tdm = new TermDocumentMatrix(new AssociativeList<String>(), documents, lengthArray);
for (Pair<String,Vector> each: termRowPairs()) {
tdm.addToRow(stemmer.stem(each.fst), each.snd);
}
return tdm;
}
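// Legacy persistence code, kept commented out; it refers to an older Document-based API of this class.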
// public void storeOn(Appendable app) {
// PrintOn out = new PrintOn(app);
// out.print("# Term-Document-Matrix").cr();
// out.print(this.termCount()).cr();
// for (String term: terms) {
// out.print(term).cr();
// }
// out.print(this.documentCount()).cr();
// for (Document doc: documents) {
// out.print(doc.name().replace(' ', '_')).tab().print(doc.version().replace(' ', '_')).cr();
// }
// matrix.storeSparseOn(app);
// }
// public void storeOn(String filename) {
// this.storeOn(Files.openWrite(filename));
// }
// public static TermDocumentMatrix readFrom(Scanner scan) {
// TermDocumentMatrix tdm = new TermDocumentMatrix();
// if (scan.hasNext("#")) scan.findInLine(".*");
//
// int termSize = scan.nextInt();
// for (int i = 0; i < termSize; i++) {
// String term = scan.next();
// tdm.indexTerm(term);
// }
// assert tdm.termCount() == termSize;
//
// int documentSize = scan.nextInt();
// for (int i = 0; i < documentSize; i++) {
// String name = scan.next();
// String version = scan.next();
// tdm.makeDocument(name, version);
// }
// assert tdm.documentCount() == documentSize;
//
// tdm.matrix = SparseMatrix.readFrom(scan);
//
// return tdm;
// }
@Override
public Terms terms() {
Terms bag = new Terms();
for (Pair<String,Vector> each: termRowPairs()) {
bag.add(each.fst, (int) each.snd.sum());
}
return bag;
}
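/** Pairs each term with its matrix row, in row order. */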
private Iterable<Pair<String,Vector>> termRowPairs() {
return Pair.zip(terms, matrix.rows());
}
@Override
public int termCount() {
return terms.size();
}
public TermDocumentMatrix toLowerCase() {
TermDocumentMatrix tdm = new TermDocumentMatrix(new AssociativeList<String>(), documents, lengthArray);
for (Pair<String,Vector> each: termRowPairs()) {
tdm.addToRow(each.fst.toString().toLowerCase(), each.snd);
}
return tdm;
}
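/** Returns a copy in which each cell holds the local weight of its value multiplied by the global weight of its row; the per-term global weights are kept and passed on to {@link #createIndex}. */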
public TermDocumentMatrix weight(LocalWeighting localWeighting, GlobalWeighting globalWeighting) {
TermDocumentMatrix tdm = new TermDocumentMatrix(this.terms, this.documents, lengthArray);
tdm.globalWeightings = new double[terms.size()];
for (Each<Vector> row: withIndex(matrix.rows())) {
double global = tdm.globalWeightings[row.index] = globalWeighting.weight(row.value);
for (Entry column: row.value.entries()) {
tdm.matrix.put(row.index, column.index, localWeighting.weight(column.value) * global);
}
}
return tdm;
}
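/**
 * Reconstructs the term bag of the given document from its matrix column.
 * Throws {@link NoSuchElementException} if the document is unknown.
 */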
@Override
public Terms getDocument(String doc) {
int column = documents.get(doc);
if (column == -1) throw new NoSuchElementException();
Terms bag = new Terms();
for (Pair<String,Vector> each: termRowPairs()) {
int count = (int) each.snd.get(column);
bag.add(each.fst, count);
}
return bag;
}
// public TermDocumentMatrix copyUpto(String version, String[] versions) {
// TermDocumentMatrix copy = new TermDocumentMatrix();
// for (String each: versions) {
// for (Document doc: this.documents()) {
// if (doc.version().equals(version)) {
// copy.makeDocument(doc.name(), doc.version()).addTerms(doc.terms());
// }
// }
// if (version.equals(each)) return copy;
// }
// throw new Error();
// }
public SparseMatrix matrix() {
return matrix;
}
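/** Recomputes the length of each document by summing its matrix column; note that after {@link #weight} the cells hold weighted values rather than counts. */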
public int[] getAllDocumentLength() {
int[] length = new int[documents.size()];
for (Vector row: matrix.rows()) {
for (Entry each: row.entries()) {
length[each.index] += each.value;
}
}
return length;
}
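/** Returns the recorded length of the given document, or -1 if the document is unknown. */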
public int getDocumentLength(String doc) {
int index = documents.get(doc);
if (index < 0) return -1;
return lengthArray.get(index);
}
}