/*
* Copyright (C) 2011 Alasdair C. Hamilton
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>
*/
package ket.treeDiff;
import java.io.*;
import java.util.*;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.awt.Graphics2D;
import ket.math.*;
import geom.Offset;
import geom.Position;
import java.awt.Color;
import ket.display.box.Box;
import ketUI.panel.KetPanel;
import ket.display.Transition;
import ket.display.ImageTools;
import ket.display.ColourScheme;
import ket.display.ColourSchemeDecorator;
import ketUI.Ket;
/**
* This class associates pairs of elements in pairs of equations, where
* possible explaining how one equation was transformed into another. This has
* applications in displaying the animated transition due to edits.
*/
public class TreeDiff {
final Argument before;
final Argument after;
final ArgumentVector rowVector;
final ArgumentVector columnVector;
final Cell[][] table;
final Vector<Step> steps;
static final boolean DEBUG = false;
public TreeDiff(Argument before, Argument after) {
this.before = before;
this.after = after;
this.rowVector = new ArgumentVector(before, ArgumentVector.INCLUDE_ROOT);
this.columnVector = new ArgumentVector(after, ArgumentVector.INCLUDE_ROOT);
table = new Cell[rowVector.size()][columnVector.size()];
initTable();
steps = new Vector<Step>();
calcSteps();
}
private void initTable() {
for (int i=0; i<rows(); i++) {
Argument b = rowVector.get(i);
for (int j=0; j<cols(); j++) {
Argument a = columnVector.get(j);
table[i][j] = new Cell(b, a);
}
}
}
private void markAndAddDescendents(int I, int J) {
ArgumentVector u = new ArgumentVector(rowVector.get(I), ArgumentVector.INCLUDE_ROOT);
ArgumentVector v = new ArgumentVector(columnVector.get(J), ArgumentVector.INCLUDE_ROOT);
assert u.size()==v.size();
for (int k=0; k<u.size(); k++) {
Argument m = u.get(k);
Argument n = v.get(k);
int i = m.indexIn(rowVector);
int j = n.indexIn(columnVector);
markUsedRowAndColumn(i, j);
steps.add(new Step(m, n));
}
if (DEBUG) Ket.out.println();
}
private void calcSteps() {
if (DEBUG) Ket.out.println();
if (DEBUG) Ket.out.printf("%60sCALC STEPS\n", "");
identifyRowPatterns();
identifyColumnPatterns();
// Exactly equal matches
if (DEBUG) Ket.out.println();
if (DEBUG) Ket.out.println("Exact matches:");
for (int i=0; i<rows(); i++) {
for (int j=0; j<cols(); j++) {
// Given some cells unambiguously match and their sub-branches equal one another, identify them
// and all their children.
Cell cell = get(i, j);
Step step = cell.identifyEquals();
if (step!=null) {
steps.add(step);
if (DEBUG) Ket.out.println("\t"+i+"->"+j+";'"+rowVector.get(i)+"'->'"+columnVector.get(j)+"'");
markAndAddDescendents(i, j);
}
}
}
if (DEBUG) Ket.out.println("size: " + steps.size());
// Unambiguous matches
if (DEBUG) Ket.out.println();
if (DEBUG) Ket.out.println("Unambiguous matches:");
for (int i=0; i<rows(); i++) {
for (int j=0; j<cols(); j++) {
// Identify all unambiguous matches (possibly with slightly different descendants).
Cell cell = get(i, j);
Step step = cell.identifyUnabmiguous();
if (step!=null) {
// Note: cell.before and cell.after can still be different.
steps.add(step);
markUsedRowAndColumn(i, j);
if (DEBUG) Ket.out.println("\t"+i+"->"+j+";'"+rowVector.get(i)+"'->'"+columnVector.get(j)+"'");
}
}
}
if (DEBUG) Ket.out.println("size: " + steps.size());
// Ambiguous matches
if (DEBUG) Ket.out.println();
if (DEBUG) Ket.out.println("Ambiguous matches:");
HashMap<Integer, Set<Integer>> rowToColumns = new HashMap<Integer, Set<Integer>>();
for (int i=0; i<rows(); i++) {
// Identify unexplained rows and their corresponding unexplained columns.
if (unexplainedRow(i)) { // <--- refactor
Set<Integer> columns = getUnusedColumnsInRow(i);
if (columns.size()>=2) {
rowToColumns.put(i, columns);
}
}
}
if (DEBUG) Ket.out.println("\trow->cols");
if (DEBUG) Ket.out.println("\t\t" + rowToColumns);
HashMap<Set<Integer>, Set<Integer>> columnsToRows = new HashMap<Set<Integer>, Set<Integer>>();
for (int r : rowToColumns.keySet()) {
// Given a collection of columns associated with each unexplained row,
// Return a map from each column index pattern to the set of unexplained rows they appear in.
Set<Integer> columnsOfValidRow = rowToColumns.get(r);
ArgumentTools.appendToMapValue(columnsToRows, columnsOfValidRow, r);
}
if (DEBUG) Ket.out.println("\tcols->rows");
if (DEBUG) Ket.out.println("\t\t" + columnsToRows);
for (Set<Integer> cs : new Vector<Set<Integer>>(columnsToRows.keySet())) { // Shallow copy.
// Given a map from column index patterns to the unexplained rows in which they appear, exclude single rows.
Set<Integer> row = columnsToRows.get(cs);
// TODO: Add the single row cases?
if (row.size()<2) {
if (DEBUG) Ket.out.print("-");
columnsToRows.remove(cs);
}
}
if (DEBUG) Ket.out.println();
if (DEBUG) Ket.out.println("\tcols->rows'");
if (DEBUG) Ket.out.println("\t\t" + columnsToRows);
for (Set<Integer> cs : columnsToRows.keySet()) {
// For each unexplained sub-table:
// A sub-table is specified by a set of row indices and a set of column indices.
// Identify a 'permutation matrix', the best (i.e. highest merit) row-column pairs, from this table.
Set<Integer> rs = columnsToRows.get(cs);
Map<Integer, Integer> columnToBestRow = findBestPairsByRow(rs, cs); // By row
if (DEBUG) Ket.out.println("columnToBestRow: " + columnToBestRow.values() + " -*> " + columnToBestRow.keySet());
for (int u : new HashSet<Integer>(columnToBestRow.values())) { // cs'th group: row
if (DEBUG) Ket.out.print("\t\t");
for (int v : columnToBestRow.keySet()) { // cs'th group : col
Cell cell = get(u, v);
if (DEBUG) Ket.out.printf("%8s->%-8s ", rowVector.get(u), columnVector.get(v));
}
if (DEBUG) Ket.out.println();
}
Set<Integer> rows = new HashSet<Integer>(columnToBestRow.values());
for (int r : rows) {
int bestColumn = findBestColumn(r, columnToBestRow);
Cell cell = get(r, bestColumn);
steps.add(new Step(rowVector.get(r), columnVector.get(bestColumn)));
if (DEBUG) Ket.out.println("\t"+r+"->"+bestColumn+";'"+rowVector.get(r)+"'->'"+columnVector.get(bestColumn)+"'");
markUsedRowAndColumn(r, bestColumn); // cell
}
}
if (DEBUG) Ket.out.println("table:");
Map<Integer, Integer> remains = new HashMap<Integer, Integer>(); // rows to columns
for (int i=0; i<rows(); i++) {
if (DEBUG) Ket.out.print("\t");
for (int j=0; j<cols(); j++) {
Cell cell = get(i, j);
double merit = cell.getMerit();
if (merit>0.0) {
if (DEBUG) Ket.out.printf("%6.3g", merit );
} else {
if (DEBUG) Ket.out.print(" .");
}
if (cell.has(Flags.EXPLAINED)) {
if (DEBUG) Ket.out.print(":E");
} else if (cell.has(Flags.USED)) {
if (DEBUG) Ket.out.print(" ");
} else {
if (DEBUG) Ket.out.print(":!");
remains.put(i, j);
}
}
if (DEBUG) Ket.out.println();
}
if (remains.size()==1) {
int r = remains.keySet().iterator().next();
int c = remains.get(r);
// If there is a single leftover, just add it.
steps.add(new Step(rowVector.get(r), columnVector.get(c)));
markUsedRowAndColumn(r, c); // cell
}
if (DEBUG) Ket.out.println("size: " + steps.size());
}
public Vector<Transition> getTransitions(Box beforeRootBox, Box afterRootBox) {
// Ensure that beforeRootBox and afterRootBox have been set.
Vector<Transition> transitions = new Vector<Transition>();
for (Step step : steps) {
Transition t = step.getTransition(beforeRootBox, afterRootBox);
if (t!=null) {
transitions.add(t);
}
}
return transitions;
}
public Argument getBefore() {
return before;
}
public Argument getAfter() {
return after;
}
//////////////////////
// INTERNAL METHODS //
//////////////////////
/**
* Iterate through the rows and extract information on each.
*/
private void identifyRowPatterns() {
for (int i=0; i<rows(); i++) {
Cell match = findUniqueMatch(getRow(i));
if (match!=null) {
match.addBeforeFlag(Flags.UNIQUELY_EQUAL);
}
}
}
/**
* Iterate through the columns and extract information on each.
*/
private void identifyColumnPatterns() {
for (int j=0; j<cols(); j++) {
Cell match = findUniqueMatch(getColumn(j));
if (match!=null) {
match.addAfterFlag(Flags.UNIQUELY_EQUAL);
}
}
}
/**
* Extract information about a given row or column of elements such as
* uniqueness of matches then update the this and the cells' properties
* as required.
*/
private static Cell findUniqueMatch(Cell[] elements) {
Cell uniquelyEqual = null; // No matches.
for (Cell cell : elements) {
if (cell.has(Flags.ELEMENT_EQUALS)) {
if (uniquelyEqual!=null) {
// A match was already found.
return null;
}
uniquelyEqual = cell; // First match.
}
}
return uniquelyEqual;
}
private int rows() {
return rowVector.size();
}
private int cols() {
return columnVector.size();
}
private Cell get(int i, int j) {
return table[i][j];
}
private Cell[] getRow(int i) {
Cell[] row = new Cell[cols()];
for (int j=0; j<cols(); j++) {
row[j] = table[i][j];
}
return row;
}
private Cell[] getColumn(int j) {
Cell[] column = new Cell[rows()];
for (int i=0; i<rows(); i++) {
column[i] = table[i][j];
}
return column;
}
/**
* If table cell (i,j) is explained, then mark all cells in the same
* row and column as used.
*/
private void markUsedRowAndColumn(int i, int j) {
get(i, j).add(Flags.EXPLAINED);
for (int J=0; J<cols(); J++) {
get(i, J).add(Flags.USED);
}
for (int I=0; I<rows(); I++) {
get(I, j).add(Flags.USED);
}
}
/**
* If all elements of column i are unexplained, iterate through row i
* and return the set of non-zero merit cells.
*/
private Set<Integer> getUnusedColumnsInRow(int i) {
Set<Integer> column = new HashSet<Integer>();
for (int j=0; j<cols(); j++) {
//- if (get(i, j).getMerit()>0) { //?
if (!get(i, j).has(Flags.USED)) {
column.add(j);
}
}
return column;
}
/**
* If every cell in row i is unexplained, return true.
*/
private boolean unexplainedRow(int i) {
for (int j=0; j<cols(); j++) {
if (get(i, j).has(Flags.EXPLAINED)) {
return false;
}
}
return true;
}
/**
* Return the best cell match (by merit value) from those specified by
* mappings of columnToBestRow.
*/
private int findBestColumn(Integer rowIndex, Map<Integer, Integer> columnToBestRow) {
int bestColumn = -1;
double bestMerit = -1;
for (int column : columnToBestRow.keySet()) {
if (columnToBestRow.get(column)!=rowIndex) {
continue;
}
double merit = get(rowIndex, column).getMerit();
if (bestColumn==-1 || merit>bestMerit) {
bestColumn = column;
bestMerit = merit;
}
}
return bestColumn;
}
private Map<Integer, Integer> findBestPairsByRow(Set<Integer> rows, Set<Integer> columns) {
Map<Integer, Integer> columnToBestRow = new HashMap<Integer, Integer>();
for (int c : columns) {
// For each column, identify the row for which the cell's merit is largest.
int bestRow = -1;
double bestMerit = -1;
for (int r : rows) {
double merit = get(r, c).getMerit();
if (bestRow==-1 || merit>bestMerit) {
bestMerit = merit;
bestRow = r;
}
}
columnToBestRow.put(c, bestRow);
}
return columnToBestRow;
}
}