assertEquals(13, tarr[1].get(1, 1), 0);
}
@Test
public void testMatrix() {
    // Plain 5x6 matrix: shape and the per-dimension element distances (strides).
    Matrix m = TensorFactory.tensor(5, 6);
    assertEquals(5, m.getRows(), 0);
    assertEquals(6, m.getColumns(), 0);
    assertEquals(6, m.getDimensionElementsDistance(0), 0);
    assertEquals(1, m.getDimensionElementsDistance(1), 0);
    assertEquals(5, m.getDimensions()[0], 0);
    assertEquals(6, m.getDimensions()[1], 0);
    fillOneBased(m.getElements());
    // With the strides asserted above (6, 1), get(r, c) reads flat index 6 * r + c,
    // which holds that index + 1.
    assertEquals(2, m.get(0, 1), 0);
    assertEquals(15, m.get(2, 2), 0);

    // Degenerate shapes: a single row ...
    m = TensorFactory.tensor(1, 6);
    fillOneBased(m.getElements());
    assertEquals(2, m.get(0, 1), 0);
    assertEquals(6, m.get(0, 5), 0);

    // ... and a single column.
    m = TensorFactory.tensor(6, 1);
    fillOneBased(m.getElements());
    assertEquals(2, m.get(1, 0), 0);
    assertEquals(6, m.get(5, 0), 0);

    // submatrix: views of a 5x5x5 tensor holding 1..125. The int[][] argument is a
    // {from, to} pair of corner coordinates; the expected values below imply both corners
    // are inclusive and the parent is laid out with strides (25, 5, 1).
    Tensor t = TensorFactory.tensor(5, 5, 5);
    fillOneBased(t.getElements());

    // Dimension 0 fixed at 1: a 5x5 slice whose (0, 0) is flat index 25.
    // (A verbatim duplicate of this call and its four assertions was removed.)
    m = TensorFactory.tensor(t, new int[][] { { 1, 0, 0 }, { 1, 4, 4 } });
    assertEquals(26, m.get(0, 0), 0);
    assertEquals(27, m.get(0, 1), 0);
    assertEquals(36, m.get(2, 0), 0);
    assertEquals(38, m.get(2, 2), 0);

    // Dimension 2 fixed at 1: rows walk dimension 0 (stride 25), columns walk dimension 1 (stride 5).
    m = TensorFactory.tensor(t, new int[][] { { 0, 0, 1 }, { 4, 4, 1 } });
    assertEquals(2, m.get(0, 0), 0);
    assertEquals(7, m.get(0, 1), 0);
    assertEquals(12, m.get(0, 2), 0);
    assertEquals(27, m.get(1, 0), 0);
    assertEquals(32, m.get(1, 1), 0);
    assertEquals(37, m.get(1, 2), 0);

    // A 2x2 window inside the dimension-2 == 1 plane.
    m = TensorFactory.tensor(t, new int[][] { { 2, 2, 1 }, { 3, 3, 1 } });
    assertEquals(62, m.get(0, 0), 0);
    assertEquals(67, m.get(0, 1), 0);
    assertEquals(92, m.get(1, 1), 0);

    // The iterator yields flat indices into m.getElements(), row by row:
    // (0, 0) -> 62, (0, 1) -> 67, (1, 0) -> 87 (skipped here), (1, 1) -> 92.
    Iterator<Integer> it = m.iterator();
    assertEquals(62, m.getElements()[it.next()], 0);
    assertEquals(67, m.getElements()[it.next()], 0);
    it.next();
    assertEquals(92, m.getElements()[it.next()], 0);

    // A bounded iterator over {from (1, 0), to (1, 1)} visits (1, 0) and then (1, 1).
    it = m.iterator(new int[][] { { 1, 0 }, { 1, 1 } });
    it.next();
    assertEquals(92, m.getElements()[it.next()], 0);

    // Submatrix of a matrix: the central 2x2 of a 4x4 holding 1..16.
    m = TensorFactory.tensor(4, 4);
    fillOneBased(m.getElements());
    Matrix m2 = TensorFactory.tensor(m, new int[][] { { 1, 1 }, { 2, 2 } });
    assertEquals(6, m2.get(0, 0), 0);
    assertEquals(7, m2.get(0, 1), 0);
    assertEquals(10, m2.get(1, 0), 0);
    assertEquals(11, m2.get(1, 1), 0);
}

/**
 * Overwrites {@code elements} in place with 1, 2, ..., elements.length, so each
 * element holds its flat index plus one.
 *
 * @param elements the array to fill; modified in place
 */
private static void fillOneBased(float[] elements) {
    for (int i = 0; i < elements.length; i++) {
        elements[i] = i + 1;
    }
}