Usage of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
From class RegressionEvalTest, method testRegressionEvaluationMerging.
/**
 * Verifies that merging several independently accumulated {@link RegressionEvaluation}
 * instances produces the same per-column metrics as one evaluation that was fed every
 * minibatch directly.
 */
@Test
public void testRegressionEvaluationMerging() {
    Nd4j.getRandom().setSeed(12345);
    final int rows = 20;
    final int columns = 3;
    final int batchesPerEval = 5;
    final int evalCount = 4;

    // "reference" sees every minibatch; each partial sees only its own share of them.
    List<RegressionEvaluation> partials = new ArrayList<>();
    RegressionEvaluation reference = new RegressionEvaluation(columns);
    for (int e = 0; e < evalCount; e++) {
        RegressionEvaluation partial = new RegressionEvaluation(columns);
        partials.add(partial);
        for (int b = 0; b < batchesPerEval; b++) {
            // Draw order (predictions first, then labels) is kept so the seeded sequence is unchanged.
            INDArray predictions = Nd4j.rand(rows, columns);
            INDArray labels = Nd4j.rand(rows, columns);
            reference.eval(labels, predictions);
            partial.eval(labels, predictions);
        }
    }

    // Fold every remaining partial into the first one; the first partial is mutated in place.
    RegressionEvaluation merged = partials.get(0);
    for (RegressionEvaluation other : partials.subList(1, partials.size())) {
        merged.merge(other);
    }

    final double tolerance = 1e-6;
    for (int col = 0; col < columns; col++) {
        assertEquals(reference.correlationR2(col), merged.correlationR2(col), tolerance);
        assertEquals(reference.meanAbsoluteError(col), merged.meanAbsoluteError(col), tolerance);
        assertEquals(reference.meanSquaredError(col), merged.meanSquaredError(col), tolerance);
        assertEquals(reference.relativeSquaredError(col), merged.relativeSquaredError(col), tolerance);
        assertEquals(reference.rootMeanSquaredError(col), merged.rootMeanSquaredError(col), tolerance);
    }
}
Usage of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
From class RegressionEvalTest, method testPerfectPredictions.
/**
 * Passes the same array as both labels and predictions to a {@link RegressionEvaluation}
 * and checks that, for every column, all error metrics are zero and the R^2 correlation
 * is one.
 */
@Test
public void testPerfectPredictions() {
    final int columns = 5;
    final int batches = 100;
    final int rowsPerBatch = 3;

    RegressionEvaluation evaluation = new RegressionEvaluation(columns);
    for (int b = 0; b < batches; b++) {
        // A single array serves as both labels and predictions, so every prediction is exact.
        INDArray values = Nd4j.rand(rowsPerBatch, columns);
        evaluation.eval(values, values);
    }
    System.out.println(evaluation.stats());

    final double tolerance = 1e-6;
    for (int col = 0; col < columns; col++) {
        assertEquals(0.0, evaluation.meanSquaredError(col), tolerance);
        assertEquals(0.0, evaluation.meanAbsoluteError(col), tolerance);
        assertEquals(0.0, evaluation.rootMeanSquaredError(col), tolerance);
        assertEquals(0.0, evaluation.relativeSquaredError(col), tolerance);
        assertEquals(1.0, evaluation.correlationR2(col), tolerance);
    }
}
Usage of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
From class EvalTest, method testEval2.
/**
 * Builds a known two-class confusion matrix by repeated calls to
 * {@link Evaluation#eval(INDArray, INDArray)} and checks the resulting per-class counts
 * (with class 0 as the positive class) and the overall accuracy.
 */
@Test
public void testEval2() {
    // Confusion matrix (rows = actual class, columns = predicted class):
    //                predicted 0   predicted 1
    //   actual 0          20             3
    //   actual 1          10             5
    // With class 0 as the positive class these cells are TP, FN / FP, TN.
    // Named once so the eval loops and the assertions below cannot drift apart.
    final int truePos = 20;
    final int falseNeg = 3;
    final int falsePos = 10;
    final int trueNeg = 5;

    Evaluation evaluation = new Evaluation(Arrays.asList("class0", "class1"));
    INDArray predicted0 = Nd4j.create(new double[] {1, 0});
    INDArray predicted1 = Nd4j.create(new double[] {0, 1});
    INDArray actual0 = Nd4j.create(new double[] {1, 0});
    INDArray actual1 = Nd4j.create(new double[] {0, 1});

    for (int i = 0; i < truePos; i++) {
        evaluation.eval(actual0, predicted0);
    }
    for (int i = 0; i < falseNeg; i++) {
        evaluation.eval(actual0, predicted1);
    }
    for (int i = 0; i < falsePos; i++) {
        evaluation.eval(actual1, predicted0);
    }
    for (int i = 0; i < trueNeg; i++) {
        evaluation.eval(actual1, predicted1);
    }

    // Compare counts as integers (same idiom as the other EvalTest methods) instead of
    // routing them through the floating-point overload with a zero delta.
    assertEquals(truePos, (int) evaluation.truePositives().get(0));
    assertEquals(falseNeg, (int) evaluation.falseNegatives().get(0));
    assertEquals(falsePos, (int) evaluation.falsePositives().get(0));
    assertEquals(trueNeg, (int) evaluation.trueNegatives().get(0));

    double expectedAccuracy =
            (double) (truePos + trueNeg) / (truePos + falseNeg + falsePos + trueNeg);
    assertEquals(expectedAccuracy, evaluation.accuracy(), 1e-6);
    System.out.println(evaluation.confusionToString());
}
Usage of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
From class EvalTest, method testSingleClassBinaryClassification.
/**
 * Checks an {@link Evaluation} built with a single output column, where the 0/1 value in
 * that column acts as the class label (the assertions below inspect both class 0 and
 * class 1).
 */
@Test
public void testSingleClassBinaryClassification() {
    Evaluation eval = new Evaluation(1);
    INDArray zero = Nd4j.create(1);
    INDArray one = Nd4j.ones(1);

    // Each row is {actual, predicted}: the first pair is wrong, the other three are right.
    INDArray[][] observations = {
        {one, zero},
        {one, one},
        {one, one},
        {zero, zero},
    };
    for (INDArray[] observation : observations) {
        eval.eval(observation[0], observation[1]);
    }
    System.out.println(eval.stats());

    assertEquals(0.75, eval.accuracy(), 1e-6);
    assertEquals(4, eval.getNumRowCounter());
    assertEquals(1, (int) eval.truePositives().get(0));
    assertEquals(2, (int) eval.truePositives().get(1));
    assertEquals(1, (int) eval.falseNegatives().get(1));
}
Usage of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
From class EvalTest, method testEvalMethods.
/**
 * Checks that {@code eval(int, int)} and {@code eval(INDArray, INDArray)} record the same
 * observations. The two overloads take their arguments in opposite order: the INDArray
 * form is (actual, predicted) and the int form is (predicted, actual).
 */
@Test
public void testEvalMethods() {
    Evaluation e1 = new Evaluation(4);
    Evaluation e2 = new Evaluation(4);
    INDArray[] oneHot = {
        Nd4j.create(new double[] {1, 0, 0, 0}),
        Nd4j.create(new double[] {0, 1, 0, 0}),
        Nd4j.create(new double[] {0, 0, 1, 0}),
        Nd4j.create(new double[] {0, 0, 0, 1}),
    };

    // Each row is {actual, predicted}.
    int[][] observations = {
        {0, 0},
        {0, 2},
        {0, 2},
        {1, 2},
        {3, 3},
        {3, 0},
        {3, 0},
    };
    for (int[] obs : observations) {
        int actual = obs[0];
        int predicted = obs[1];
        // INDArray overload: (actual, predicted)
        e1.eval(oneHot[actual], oneHot[predicted]);
        // int overload: (predicted, actual)
        e2.eval(predicted, actual);
    }

    // getCount arguments are (actual, predicted).
    ConfusionMatrix<Integer> cm = e1.getConfusionMatrix();
    assertEquals(1, cm.getCount(0, 0));
    assertEquals(2, cm.getCount(0, 2));
    assertEquals(1, cm.getCount(1, 2));
    assertEquals(1, cm.getCount(3, 3));
    assertEquals(2, cm.getCount(3, 0));

    System.out.println(e1.stats());
    System.out.println(e2.stats());
    assertEquals(e1.stats(), e2.stats());
}
Aggregations