Search in sources :

Example 1 with ROC

Use of org.deeplearning4j.eval.ROC in project deeplearning4j by deeplearning4j.

From the class TestSparkMultiLayerParameterAveraging, method testROC.

/**
 * Checks that ROC evaluation performed through {@code SparkDl4jMultiLayer.evaluateROC}
 * matches ROC evaluation accumulated locally on the same network and data.
 *
 * <p>The test builds a small two-layer network, generates {@code nArrays} random
 * minibatches with random one-hot labels, feeds the network's predictions into a
 * local {@link ROC} instance, and then evaluates the same data sets via Spark.
 * Both the AUC and the two result arrays returned by {@code getResultsAsArray()}
 * are required to agree within {@code 1e-6}.
 */
@Test
public void testROC() {
    int nArrays = 100;
    int minibatch = 64;
    int steps = 20;
    int nIn = 5;
    int nOut = 2;
    int layerSize = 10;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(layerSize).build())
            .layer(1, new OutputLayer.Builder()
                    .nIn(layerSize)
                    .nOut(nOut)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    // Fixed seeds so the generated features and labels are reproducible.
    Nd4j.getRandom().setSeed(12345);
    Random r = new Random(12345);

    ROC local = new ROC(steps);
    List<DataSet> dsList = new ArrayList<>(nArrays);
    for (int i = 0; i < nArrays; i++) {
        INDArray features = Nd4j.rand(minibatch, nIn);
        INDArray predictions = net.output(features);

        // One-hot labels: exactly one randomly chosen class per example.
        INDArray labels = Nd4j.zeros(minibatch, nOut);
        for (int j = 0; j < minibatch; j++) {
            labels.putScalar(j, r.nextInt(nOut), 1.0);
        }

        local.eval(labels, predictions);
        dsList.add(new DataSet(features, labels));
    }

    // NOTE(review): the null third argument is presumably the training master, which
    // evaluation does not appear to need here - confirm against SparkDl4jMultiLayer.
    SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, net, null);
    JavaRDD<DataSet> rdd = sc.parallelize(dsList);

    // NOTE(review): 32 is presumably the evaluation minibatch size - confirm.
    ROC sparkROC = sparkNet.evaluateROC(rdd, steps, 32);

    // Compare the locally computed AUC (expected) with the Spark one (actual).
    // Comparing sparkROC against itself would pass unconditionally.
    assertEquals(local.calculateAUC(), sparkROC.calculateAUC(), 1e-6);

    double[][] arrLocal = local.getResultsAsArray();
    double[][] arrSpark = sparkROC.getResultsAsArray();
    assertArrayEquals(arrLocal[0], arrSpark[0], 1e-6);
    assertArrayEquals(arrLocal[1], arrSpark[1], 1e-6);
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) DataSet(org.nd4j.linalg.dataset.DataSet) ROC(org.deeplearning4j.eval.ROC) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) LabeledPoint(org.apache.spark.mllib.regression.LabeledPoint) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SparkDl4jMultiLayer(org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Aggregations

LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint)1 ROC (org.deeplearning4j.eval.ROC)1 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)1 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)1 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)1 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)1 BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)1 SparkDl4jMultiLayer (org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer)1 Test (org.junit.Test)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1 DataSet (org.nd4j.linalg.dataset.DataSet)1 MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)1