
Example 1 with ROCMultiClass

Use of org.deeplearning4j.eval.ROCMultiClass in project deeplearning4j by deeplearning4j.

From class KerasModelEndToEndTest, method compareMulticlassAUC.

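import org.deeplearning4j.eval.ROCMultiClass;
import org.nd4j.linalg.api.ndarray.INDArray;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

// Asserts that two sets of predictions (a and b) for the same targets give the same
// average and per-class AUC, within tolerance eps
public static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses, double eps) {
    // One-vs-all ROC evaluation with 100 threshold steps for each set of predictions
    ROCMultiClass evalA = new ROCMultiClass(100);
    evalA.eval(target, a);
    double avgAucA = evalA.calculateAverageAUC();
    ROCMultiClass evalB = new ROCMultiClass(100);
    evalB.eval(target, b);
    double avgAucB = evalB.calculateAverageAUC();
    assertEquals(label + ": average AUC", avgAucA, avgAucB, eps);
    // Compare the AUC of every class individually
    double[] aucA = new double[nbClasses];
    double[] aucB = new double[nbClasses];
    for (int i = 0; i < nbClasses; i++) {
        aucA[i] = evalA.calculateAUC(i);
        aucB[i] = evalB.calculateAUC(i);
    }
    assertArrayEquals(label + ": per-class AUC", aucA, aucB, eps);
}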
Also used: ROCMultiClass (org.deeplearning4j.eval.ROCMultiClass)
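compareMulticlassAUC relies on ROCMultiClass building a one-vs-all ROC curve for each class from a fixed number of thresholds (100 here) and averaging the per-class AUCs. The plain-Java sketch below illustrates that idea without DL4J. The class OneVsAllRoc and its methods are hypothetical, and the evenly spaced thresholds, the `>=` comparison, and the trapezoidal integration are assumptions of this sketch, not a description of DL4J's exact implementation.

```java
/**
 * Hypothetical sketch (not DL4J code): one-vs-all ROC AUC from a fixed number of
 * evenly spaced thresholds, plus the unweighted average over classes.
 */
final class OneVsAllRoc {

    private OneVsAllRoc() {}

    /**
     * AUC for class c, treating c as the positive class and all others as negative.
     * labels are one-hot rows; probs holds the predicted probability of each class.
     */
    static double auc(double[][] labels, double[][] probs, int c, int thresholdSteps) {
        int pos = 0;
        int neg = 0;
        for (double[] row : labels) {
            if (row[c] == 1.0) pos++; else neg++;
        }
        // Index 0 is the (0, 0) corner; index s + 1 uses threshold 1.0 - s / thresholdSteps
        double[] fpr = new double[thresholdSteps + 2];
        double[] tpr = new double[thresholdSteps + 2];
        for (int s = 0; s <= thresholdSteps; s++) {
            double threshold = 1.0 - (double) s / thresholdSteps;
            int tp = 0;
            int fp = 0;
            for (int i = 0; i < labels.length; i++) {
                if (probs[i][c] >= threshold) {
                    if (labels[i][c] == 1.0) tp++; else fp++;
                }
            }
            tpr[s + 1] = pos == 0 ? 0.0 : (double) tp / pos;
            fpr[s + 1] = neg == 0 ? 0.0 : (double) fp / neg;
        }
        // Trapezoidal rule over consecutive (FPR, TPR) points
        double area = 0.0;
        for (int k = 1; k < fpr.length; k++) {
            area += (fpr[k] - fpr[k - 1]) * (tpr[k] + tpr[k - 1]) / 2.0;
        }
        return area;
    }

    /** Unweighted mean of the per-class AUCs. */
    static double averageAuc(double[][] labels, double[][] probs, int nClasses, int thresholdSteps) {
        double sum = 0.0;
        for (int c = 0; c < nClasses; c++) {
            sum += auc(labels, probs, c, thresholdSteps);
        }
        return sum / nClasses;
    }
}
```

Under these assumptions, `OneVsAllRoc.averageAuc(labels, probs, nbClasses, 100)` plays the role of the `new ROCMultiClass(100)` / `calculateAverageAUC()` pair, and `auc(..., i, 100)` the role of `calculateAUC(i)`. Comparing two prediction sets then amounts to comparing these numbers within a tolerance, as compareMulticlassAUC does with `eps`.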

Example 2 with ROCMultiClass

Use of org.deeplearning4j.eval.ROCMultiClass in project deeplearning4j by deeplearning4j.

From class TestSparkMultiLayerParameterAveraging, method testROCMultiClass.

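import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

// Method of a test class extending BaseSparkTest, which supplies the JavaSparkContext `sc`
@Test
public void testROCMultiClass() {
    int nArrays = 100;   // number of minibatches
    int minibatch = 64;  // examples per minibatch
    int steps = 20;      // ROC threshold steps
    int nIn = 5;
    int nOut = 3;
    int layerSize = 10;
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(layerSize).build())
            .layer(1, new OutputLayer.Builder().nIn(layerSize).nOut(nOut)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    Nd4j.getRandom().setSeed(12345);
    Random r = new Random(12345);
    // Reference result: ROC computed locally on every minibatch
    ROCMultiClass local = new ROCMultiClass(steps);
    List<DataSet> dsList = new ArrayList<>();
    for (int i = 0; i < nArrays; i++) {
        INDArray features = Nd4j.rand(minibatch, nIn);
        INDArray p = net.output(features);
        INDArray l = Nd4j.zeros(minibatch, nOut);  // random one-hot labels
        for (int j = 0; j < minibatch; j++) {
            l.putScalar(j, r.nextInt(nOut), 1.0);
        }
        local.eval(l, p);
        dsList.add(new DataSet(features, l));
    }
    // Same data evaluated on Spark; the TrainingMaster is null because this test only evaluates
    SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, net, null);
    JavaRDD<DataSet> rdd = sc.parallelize(dsList);
    // Arguments: data, threshold steps, evaluation minibatch size
    ROCMultiClass sparkROC = sparkNet.evaluateROCMultiClass(rdd, steps, 32);
    // Per-class AUC and ROC curve points must match the local result
    for (int i = 0; i < nOut; i++) {
        assertEquals(local.calculateAUC(i), sparkROC.calculateAUC(i), 1e-6);
        double[][] arrLocal = local.getResultsAsArray(i);
        double[][] arrSpark = sparkROC.getResultsAsArray(i);
        assertArrayEquals(arrLocal[0], arrSpark[0], 1e-6);
        assertArrayEquals(arrLocal[1], arrSpark[1], 1e-6);
    }
}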
Also used: OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer), ROCMultiClass (org.deeplearning4j.eval.ROCMultiClass), MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet), DataSet (org.nd4j.linalg.dataset.DataSet), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint), MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SparkDl4jMultiLayer (org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest), Test (org.junit.Test)
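testROCMultiClass expects the Spark result to match the local one to within 1e-6. One way this can hold is if each partition accumulates per-threshold counts and the partial results are then merged by summation: integer counts add up identically no matter how the data is split. The hypothetical ThresholdCounts class below illustrates this for a single class treated one-vs-all. It is a plain-Java sketch under that assumption, not DL4J's ROC or merge code.

```java
/**
 * Hypothetical sketch (not DL4J code): per-threshold confusion counts for one class,
 * treated one-vs-all. Because the counts are integers, results accumulated on
 * separate partitions can be merged by element-wise addition and equal the counts
 * from a single pass over all of the data.
 */
final class ThresholdCounts {
    final int steps;
    final long[] tp;  // tp[s]: positives with probability >= threshold(s)
    final long[] fp;  // fp[s]: negatives with probability >= threshold(s)
    long positives;
    long negatives;

    ThresholdCounts(int steps) {
        this.steps = steps;
        this.tp = new long[steps + 1];
        this.fp = new long[steps + 1];
    }

    /** Threshold for step s, from 1.0 (s = 0) down to 0.0 (s = steps). */
    double threshold(int s) {
        return 1.0 - (double) s / steps;
    }

    /** Records one example: whether it belongs to the class, and its predicted probability. */
    void add(boolean isPositive, double prob) {
        if (isPositive) positives++; else negatives++;
        for (int s = 0; s <= steps; s++) {
            if (prob >= threshold(s)) {
                if (isPositive) tp[s]++; else fp[s]++;
            }
        }
    }

    /** Adds another partition's counts into this one and returns this. */
    ThresholdCounts merge(ThresholdCounts other) {
        if (other.steps != steps) {
            throw new IllegalArgumentException("threshold step counts differ");
        }
        for (int s = 0; s <= steps; s++) {
            tp[s] += other.tp[s];
            fp[s] += other.fp[s];
        }
        positives += other.positives;
        negatives += other.negatives;
        return this;
    }

    /** True-positive rate at each threshold step (0 if there are no positives). */
    double[] tpr() {
        double[] r = new double[steps + 1];
        for (int s = 0; s <= steps; s++) {
            r[s] = positives == 0 ? 0.0 : (double) tp[s] / positives;
        }
        return r;
    }

    /** False-positive rate at each threshold step (0 if there are no negatives). */
    double[] fpr() {
        double[] r = new double[steps + 1];
        for (int s = 0; s <= steps; s++) {
            r[s] = negatives == 0 ? 0.0 : (double) fp[s] / negatives;
        }
        return r;
    }
}
```

In this design, filling one ThresholdCounts per partition and folding them together with `merge` yields exactly the counts, and therefore exactly the ROC points, of a single local pass. That is the kind of agreement the per-class `assertArrayEquals` checks in the test look for.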
