Search in sources :

Example 1 with ListDataSetIterator

use of org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator in project deeplearning4j by deeplearning4j.

The class BatchNormalizationTest defines the method checkMeanVarianceEstimateCNN, shown below.

/**
 * Verifies that the running (global) mean/variance estimates maintained by a
 * BatchNormalization layer over 4d convolutional activations converge to the
 * true statistics of the input distribution.
 *
 * Features are drawn from U(0,1), so each channel's mean should approach 0.5
 * and its variance should approach 1/12 (the variance of U(0,1)).
 */
@Test
public void checkMeanVarianceEstimateCNN() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    // Batch norm over 3 channels of 5x5 convolutional input, followed by an
    // identity-activation MSE output layer.
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.RMSPROP)
            .seed(12345)
            .list()
            .layer(0, new BatchNormalization.Builder().nIn(3).nOut(3).eps(1e-5).decay(0.95).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.IDENTITY)
                    .nOut(10)
                    .build())
            .backprop(true)
            .pretrain(false)
            .setInputType(InputType.convolutional(5, 5, 3))
            .build();
    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.init();

    // 100 minibatches of uniformly-random features and labels.
    int batchSize = 32;
    List<DataSet> dataSets = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
        INDArray features = Nd4j.rand(new int[] { batchSize, 3, 5, 5 });
        INDArray labels = Nd4j.rand(batchSize, 10);
        dataSets.add(new DataSet(features, labels));
    }
    DataSetIterator iterator = new ListDataSetIterator(dataSets);

    // Expected statistics of U(0,1): mean = 0.5, variance = 1/12 * (1-0)^2
    INDArray expectedMean = Nd4j.valueArrayOf(new int[] { 1, 3 }, 0.5);
    INDArray expectedVar = Nd4j.valueArrayOf(new int[] { 1, 3 }, 1 / 12.0);

    // Fit repeatedly so the exponentially-decayed running estimates settle.
    for (int epoch = 0; epoch < 10; epoch++) {
        iterator.reset();
        network.fit(iterator);
    }

    INDArray estimatedMean = network.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
    INDArray estimatedVar = network.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);

    assertArrayEquals(expectedMean.data().asFloat(), estimatedMean.data().asFloat(), 0.01f);
    assertArrayEquals(expectedVar.data().asFloat(), estimatedVar.data().asFloat(), 0.01f);
}
Also used : ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) BatchNormalization(org.deeplearning4j.nn.conf.layers.BatchNormalization) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) Test(org.junit.Test)

Example 2 with ListDataSetIterator

use of org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator in project deeplearning4j by deeplearning4j.

The class BatchNormalizationTest defines the method checkMeanVarianceEstimate, shown below.

/**
 * Verifies that the running (global) mean/variance estimates of a
 * BatchNormalization layer on 2d (dense) activations converge to the true
 * statistics of the input distribution.
 *
 * Features are drawn from U(0,1), so each of the 10 input units should show a
 * mean near 0.5 and a variance near 1/12 (the variance of U(0,1)).
 */
@Test
public void checkMeanVarianceEstimate() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    // Batch norm over 10 dense inputs, then an identity-activation MSE output layer.
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.RMSPROP)
            .seed(12345)
            .list()
            .layer(0, new BatchNormalization.Builder().nIn(10).nOut(10).eps(1e-5).decay(0.95).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.IDENTITY)
                    .nIn(10)
                    .nOut(10)
                    .build())
            .backprop(true)
            .pretrain(false)
            .build();
    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.init();

    // 200 minibatches of uniformly-random features and labels.
    int batchSize = 32;
    List<DataSet> dataSets = new ArrayList<>();
    for (int i = 0; i < 200; i++) {
        dataSets.add(new DataSet(Nd4j.rand(batchSize, 10), Nd4j.rand(batchSize, 10)));
    }
    DataSetIterator iterator = new ListDataSetIterator(dataSets);

    // Expected statistics of U(0,1): mean = 0.5, variance = 1/12 * (1-0)^2
    INDArray expectedMean = Nd4j.valueArrayOf(new int[] { 1, 10 }, 0.5);
    INDArray expectedVar = Nd4j.valueArrayOf(new int[] { 1, 10 }, 1 / 12.0);

    // Fit repeatedly so the exponentially-decayed running estimates settle.
    for (int epoch = 0; epoch < 10; epoch++) {
        iterator.reset();
        network.fit(iterator);
    }

    INDArray estimatedMean = network.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
    INDArray estimatedVar = network.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);

    assertArrayEquals(expectedMean.data().asFloat(), estimatedMean.data().asFloat(), 0.02f);
    assertArrayEquals(expectedVar.data().asFloat(), estimatedVar.data().asFloat(), 0.02f);
}
Also used : ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) Test(org.junit.Test)

Example 3 with ListDataSetIterator

use of org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator in project deeplearning4j by deeplearning4j.

The class TestCompGraphCNN defines the method getDS, shown below.

/**
 * Builds a tiny 5-example iterator for CNN tests: each example has an
 * all-zero feature vector of length 32*32*3 and a one-hot label over 10
 * classes (example i is labelled class i). All 5 examples fit in one batch.
 */
protected static DataSetIterator getDS() {
    List<DataSet> examples = new ArrayList<>(5);
    for (int classIdx = 0; classIdx < 5; classIdx++) {
        INDArray features = Nd4j.create(1, 32 * 32 * 3);
        INDArray labels = Nd4j.create(1, 10);
        labels.putScalar(classIdx, 1.0);
        examples.add(new DataSet(features, labels));
    }
    return new ListDataSetIterator(examples, 5);
}
Also used : ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList)

Example 4 with ListDataSetIterator

use of org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator in project deeplearning4j by deeplearning4j.

The class TestEarlyStopping defines the method testMinImprovementNEpochsTermination, shown below.

/**
 * Early stopping: training should terminate when the test-set score fails to
 * improve by more than minImprovement for 5 consecutive epochs. With the
 * learning rate forced to 0.0 the score can never improve, so termination is
 * expected after epoch 6 (the first epoch plus 5 non-improving ones).
 */
@Test
public void testMinImprovementNEpochsTermination() {
    Random shuffleRng = new Random(123);
    Nd4j.getRandom().setSeed(12345);
    // LR = 0.0 guarantees no score improvement ever occurs.
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .seed(123)
            .iterations(10)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(0.0)
            .updater(Updater.NESTEROVS)
            .momentum(0.9)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(1)
                    .nOut(20)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.IDENTITY)
                    .weightInit(WeightInit.XAVIER)
                    .nIn(20)
                    .nOut(1)
                    .build())
            .pretrain(false)
            .backprop(true)
            .build();
    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.setListeners(new ScoreIterationListener(1));

    // Training data: y = sin(x) on 100 points in [-10, 10], shuffled, one batch.
    int nSamples = 100;
    INDArray x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1);
    INDArray y = Nd4j.getExecutioner().execAndReturn(new Sin(x.dup()));
    List<DataSet> samples = new DataSet(x, y).asList();
    Collections.shuffle(samples, shuffleRng);
    DataSetIterator training = new ListDataSetIterator(samples, nSamples);

    double minImprovement = 0.0009;
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(
                            new MaxEpochsTerminationCondition(1000),
                            // Stop after 5 consecutive epochs without an improvement > minImprovement
                            new ScoreImprovementEpochTerminationCondition(5, minImprovement))
                    .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(3, TimeUnit.MINUTES))
                    .scoreCalculator(new DataSetLossCalculator(training, true))
                    .modelSaver(saver)
                    .build();
    IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, network, training);
    EarlyStoppingResult result = trainer.fit();

    assertEquals(6, result.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    String expectedDetails = new ScoreImprovementEpochTerminationCondition(5, minImprovement).toString();
    assertEquals(expectedDetails, result.getTerminationDetails());
}
Also used : ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) ScoreImprovementEpochTerminationCondition(org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) Random(java.util.Random) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) IEarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer) EarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) IEarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer) Sin(org.nd4j.linalg.api.ops.impl.transforms.Sin) DataSetLossCalculator(org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)

Example 5 with ListDataSetIterator

use of org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator in project deeplearning4j by deeplearning4j.

The class EvalTest defines the method testIris, shown below.

/**
 * Trains a small 2-layer softmax classifier on Iris and checks that three ways
 * of evaluating it agree: Evaluation constructed with an explicit class count,
 * Evaluation with the class count inferred from the data, and
 * MultiLayerNetwork#evaluate(DataSetIterator). Finally exercises the confusion
 * matrix rendering paths (toString/CSV/HTML).
 */
@Test
public void testIris() {
    // Small dense + softmax-output network.
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .iterations(1)
            .seed(42)
            .learningRate(1e-6)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(4)
                    .nOut(2)
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(2)
                    .nOut(3)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.SOFTMAX)
                    .build())
            .build();
    MultiLayerNetwork model = new MultiLayerNetwork(configuration);
    model.init();
    IterationListener scoreListener = new ScoreIterationListener(1);
    model.setListeners(Arrays.asList(scoreListener));

    // Load all 150 Iris examples, shuffle, and hold out 5 for training.
    DataSetIterator irisIterator = new IrisDataSetIterator(150, 150);
    DataSet allData = irisIterator.next();
    allData.shuffle();
    SplitTestAndTrain split = allData.splitTestAndTrain(5, new Random(42));

    DataSet trainData = split.getTrain();
    trainData.normalizeZeroMeanZeroUnitVariance();
    DataSet testData = split.getTest();
    testData.normalizeZeroMeanZeroUnitVariance();
    INDArray testFeatures = testData.getFeatureMatrix();
    INDArray testLabels = testData.getLabels();

    model.fit(trainData);
    INDArray predictions = model.output(testFeatures);

    // Evaluation with the class count given explicitly...
    Evaluation evalWithClassCount = new Evaluation(3);
    evalWithClassCount.eval(testLabels, predictions);
    double f1Explicit = evalWithClassCount.f1();
    double accExplicit = evalWithClassCount.accuracy();

    // ...and with the class count inferred from the data.
    Evaluation evalInferred = new Evaluation();
    evalInferred.eval(testLabels, predictions);
    double f1Inferred = evalInferred.f1();
    double accInferred = evalInferred.accuracy();

    // Both must give the same f1 and accuracy (single batch).
    assertTrue(f1Explicit == f1Inferred && accExplicit == accInferred);

    // Evaluating via the model's own evaluate() must match too.
    Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator(Collections.singletonList(testData)));
    checkEvaluationEquality(evalWithClassCount, evalViaMethod);

    // Exercise the confusion-matrix rendering paths.
    System.out.println(evalWithClassCount.getConfusionMatrix().toString());
    System.out.println(evalWithClassCount.getConfusionMatrix().toCSV());
    System.out.println(evalWithClassCount.getConfusionMatrix().toHTML());
    System.out.println(evalWithClassCount.confusionToString());
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) INDArray(org.nd4j.linalg.api.ndarray.INDArray) IterationListener(org.deeplearning4j.optimize.api.IterationListener) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) SplitTestAndTrain(org.nd4j.linalg.dataset.SplitTestAndTrain) Test(org.junit.Test)

Aggregations

ListDataSetIterator (org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator)5 INDArray (org.nd4j.linalg.api.ndarray.INDArray)5 DataSet (org.nd4j.linalg.dataset.DataSet)5 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)4 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)4 Test (org.junit.Test)4 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)4 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)2 MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator)2 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)2 ArrayList (java.util.ArrayList)1 Random (java.util.Random)1 RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator)1 InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver)1 DataSetLossCalculator (org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator)1 MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition)1 MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition)1 ScoreImprovementEpochTerminationCondition (org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition)1 EarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer)1 IEarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer)1