Usage of org.deeplearning4j.eval.ROCMultiClass in the project deeplearning4j (owner: deeplearning4j).
Class KerasModelEndToEndTest, method compareMulticlassAUC.
/**
 * Asserts that two sets of multi-class predictions yield the same ROC AUC against the same
 * targets. Both the average AUC over all classes and the AUC of every individual class are
 * compared.
 *
 * @param label     identifier for this comparison; not used by the assertions below
 * @param target    ground-truth labels, passed to {@code ROCMultiClass.eval} for both evaluations
 * @param a         first set of predictions
 * @param b         second set of predictions
 * @param nbClasses number of classes whose individual AUC is compared (indices 0..nbClasses-1)
 * @param eps       absolute tolerance used for every AUC comparison
 */
public static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses, double eps) {
    // NOTE(review): 100 is presumably the number of ROC threshold steps; both sides use the same value.
    ROCMultiClass evalA = new ROCMultiClass(100);
    evalA.eval(target, a);
    double avgAucA = evalA.calculateAverageAUC();

    ROCMultiClass evalB = new ROCMultiClass(100);
    evalB.eval(target, b);
    double avgAucB = evalB.calculateAverageAUC();

    // Honour the caller-supplied tolerance; a fixed EPS constant previously made `eps` dead.
    assertEquals(avgAucA, avgAucB, eps);

    double[] aucA = new double[nbClasses];
    double[] aucB = new double[nbClasses];
    for (int i = 0; i < nbClasses; i++) {
        aucA[i] = evalA.calculateAUC(i);
        aucB[i] = evalB.calculateAUC(i);
    }
    assertArrayEquals(aucA, aucB, eps);
}
Usage of org.deeplearning4j.eval.ROCMultiClass in the project deeplearning4j (owner: deeplearning4j).
Class TestSparkMultiLayerParameterAveraging, method testROCMultiClass.
/**
 * Checks that multi-class ROC evaluation run through Spark matches a local
 * {@code ROCMultiClass} evaluation.
 *
 * <p>Both evaluations use the same network, the same random features and the same one-hot
 * labels. Per-class AUC and the per-class result arrays returned by
 * {@code getResultsAsArray} are compared within a fixed tolerance.
 */
@Test
public void testROCMultiClass() {
    int nArrays = 100;
    int minibatch = 64;
    int steps = 20;
    int nIn = 5;
    int nOut = 3;
    int layerSize = 10;
    // NOTE(review): presumably the minibatch size Spark uses while evaluating — confirm.
    int evalBatchSize = 32;
    double eps = 1e-6;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(layerSize).build())
            .layer(1, new OutputLayer.Builder().nIn(layerSize).nOut(nOut)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    Nd4j.getRandom().setSeed(12345);
    Random r = new Random(12345);

    // Local reference evaluation, accumulated over the same data that is later sent to Spark.
    ROCMultiClass local = new ROCMultiClass(steps);
    List<DataSet> dsList = new ArrayList<>();
    for (int i = 0; i < nArrays; i++) {
        INDArray features = Nd4j.rand(minibatch, nIn);
        INDArray p = net.output(features);
        // One-hot labels: exactly one randomly chosen class per row.
        INDArray l = Nd4j.zeros(minibatch, nOut);
        for (int j = 0; j < minibatch; j++) {
            l.putScalar(j, r.nextInt(nOut), 1.0);
        }
        local.eval(l, p);
        dsList.add(new DataSet(features, l));
    }

    SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, net, null);
    JavaRDD<DataSet> rdd = sc.parallelize(dsList);
    ROCMultiClass sparkROC = sparkNet.evaluateROCMultiClass(rdd, steps, evalBatchSize);

    for (int i = 0; i < nOut; i++) {
        // Compare the local reference against Spark. Previously sparkROC was compared with
        // itself, so this assertion could never fail.
        assertEquals(local.calculateAUC(i), sparkROC.calculateAUC(i), eps);
        double[][] arrLocal = local.getResultsAsArray(i);
        double[][] arrSpark = sparkROC.getResultsAsArray(i);
        assertArrayEquals(arrLocal[0], arrSpark[0], eps);
        assertArrayEquals(arrLocal[1], arrSpark[1], eps);
    }
}
Aggregations