Search in sources :

Example 6 with OptimizationAlgorithm

use of org.deeplearning4j.nn.api.OptimizationAlgorithm in project deeplearning4j by deeplearning4j.

In class BackTrackLineSearchTest, method testBackTrackLineGradientDescent:

///////////////////////////////////////////////////////////////////////////
/**
 * Fits an initialised Iris network once with
 * {@code OptimizationAlgorithm.LINE_GRADIENT_DESCENT} and asserts that the
 * score reported after the fit is strictly lower than the score computed on
 * the same data before the fit.
 */
@Test
public void testBackTrackLineGradientDescent() {
    // Take the first DataSet produced by the Iris iterator as the training sample.
    DataSetIterator iterator = new IrisDataSetIterator(1, 1);
    DataSet sample = iterator.next();

    MultiLayerNetwork net = new MultiLayerNetwork(
            getIrisMultiLayerConfig(Activation.SIGMOID, 100, OptimizationAlgorithm.LINE_GRADIENT_DESCENT));
    net.init();

    // Attach a score listener constructed with argument 1.
    IterationListener scoreListener = new ScoreIterationListener(1);
    net.setListeners(Collections.singletonList(scoreListener));

    double scoreBefore = net.score(sample);
    net.fit(sample.getFeatureMatrix(), sample.getLabels());
    double scoreAfter = net.score();

    assertTrue(scoreBefore > scoreAfter);
}
Also used : OptimizationAlgorithm(org.deeplearning4j.nn.api.OptimizationAlgorithm) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) IterationListener(org.deeplearning4j.optimize.api.IterationListener) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)

Example 7 with OptimizationAlgorithm

use of org.deeplearning4j.nn.api.OptimizationAlgorithm in project deeplearning4j by deeplearning4j.

In class TestOptimizers, method testOptimizersMLP:

/**
 * Checks, for each optimization algorithm under test, that repeatedly
 * fitting an MLP on the Iris data never increases the score and never
 * produces a NaN score.
 */
@Test
public void testOptimizersMLP() {
    final int iterationsPerFit = 10;
    final int fitCalls = 30;

    DataSetIterator irisIterator = new IrisDataSetIterator(150, 150);
    OptimizationAlgorithm[] algorithms = {
        OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT,
        OptimizationAlgorithm.LINE_GRADIENT_DESCENT,
        OptimizationAlgorithm.CONJUGATE_GRADIENT,
        OptimizationAlgorithm.LBFGS
    };
    DataSet ds = irisIterator.next();
    ds.normalizeZeroMeanZeroUnitVariance();

    for (OptimizationAlgorithm oa : algorithms) {
        MultiLayerNetwork net = new MultiLayerNetwork(getMLPConfigIris(oa, iterationsPerFit));
        net.init();

        // The starting score must be a usable number: neither zero nor NaN.
        double score = net.score(ds);
        assertTrue(!(score == 0.0 || Double.isNaN(score)));

        if (PRINT_OPT_RESULTS) {
            System.out.println("testOptimizersMLP() - " + oa);
        }

        // history[0] is the initial score; history[k] is the score after the k-th fit.
        double[] history = new double[fitCalls + 1];
        history[0] = score;
        for (int call = 1; call <= fitCalls; call++) {
            net.fit(ds);
            double scoreAfter = net.score(ds);
            history[call] = scoreAfter;
            assertTrue("Score is NaN after optimization", !Double.isNaN(scoreAfter));
            assertTrue("OA= " + oa + ", before= " + score + ", after= " + scoreAfter, score >= scoreAfter);
            score = scoreAfter;
        }

        if (PRINT_OPT_RESULTS) {
            System.out.println(oa + " - " + Arrays.toString(history));
        }
    }
}
Also used : OptimizationAlgorithm(org.deeplearning4j.nn.api.OptimizationAlgorithm) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)

Example 8 with OptimizationAlgorithm

use of org.deeplearning4j.nn.api.OptimizationAlgorithm in project deeplearning4j by deeplearning4j.

In class TestDecayPolicies, method testUpdatingInConf:

/**
 * For each optimization algorithm, builds a two-layer network configured
 * with a learning-rate schedule and a momentum schedule, fits it on
 * successive batches from {@code MnistDataSetIterator}, and after every fit
 * asserts that the last layer's configuration reports the scheduled
 * learning rate and momentum for that step.
 *
 * @throws Exception declared by the signature; presumably propagated from
 *                   creating the MNIST iterator (confirm against its constructor)
 */
@Test
public void testUpdatingInConf() throws Exception {
    OptimizationAlgorithm[] algorithms = {
        OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT,
        OptimizationAlgorithm.LINE_GRADIENT_DESCENT,
        OptimizationAlgorithm.CONJUGATE_GRADIENT,
        OptimizationAlgorithm.LBFGS
    };
    for (OptimizationAlgorithm oa : algorithms) {
        // Build both schedules for steps 0..100: momentum grows by 0.001 per
        // step (capped at 0.9999), learning rate decays by a factor of 0.96.
        Map<Integer, Double> momentumSchedule = new HashMap<>();
        Map<Integer, Double> learningRateSchedule = new HashMap<>();
        double momentum = 0.001;
        double rate = 0.1;
        for (int step = 0; step <= 100; step++) {
            momentumSchedule.put(step, Math.min(momentum, 0.9999));
            momentum += 0.001;
            learningRateSchedule.put(step, rate);
            rate *= 0.96;
        }

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(oa)
                .iterations(1)
                .learningRateDecayPolicy(LearningRatePolicy.Schedule)
                .learningRateSchedule(learningRateSchedule)
                .updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS)
                .weightInit(WeightInit.XAVIER)
                .momentum(0.9)
                .momentumAfter(momentumSchedule)
                .regularization(true)
                .l2(0.0001)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(784).nOut(10).build())
                .layer(1, new OutputLayer.Builder().nIn(10).nOut(10).build())
                .pretrain(false)
                .backprop(true)
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        int lastLayerIndex = 1;
        DataSetIterator batches = new MnistDataSetIterator(64, true, 12345);
        // Run at most 101 fits (steps 0..100), one per scheduled entry.
        for (int step = 0; step <= 100 && batches.hasNext(); step++) {
            net.fit(batches.next());
            // Read back the learning rate for parameter "W" and the momentum
            // from the last layer's configuration after this fit.
            double lrLastLayer = net.getLayer(lastLayerIndex).conf().getLearningRateByParam("W");
            double mLastLayer = net.getLayer(lastLayerIndex).conf().getLayer().getMomentum();
            assertEquals(learningRateSchedule.get(step), lrLastLayer, 1e-6);
            assertEquals(momentumSchedule.get(step), mLastLayer, 1e-6);
        }
    }
}
Also used : OptimizationAlgorithm(org.deeplearning4j.nn.api.OptimizationAlgorithm) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) HashMap(java.util.HashMap) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) Test(org.junit.Test)

Aggregations

OptimizationAlgorithm (org.deeplearning4j.nn.api.OptimizationAlgorithm): 8, Test (org.junit.Test): 8, MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 7, IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 5, DataSet (org.nd4j.linalg.dataset.DataSet): 5, DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 5, IterationListener (org.deeplearning4j.optimize.api.IterationListener): 4, ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 4, HashMap (java.util.HashMap): 1, RecordReaderMultiDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator): 1, MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator): 1, MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 1, OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 1, MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 1