
Example 31 with ScoreIterationListener

Use of org.deeplearning4j.optimize.listeners.ScoreIterationListener in project deeplearning4j by deeplearning4j.

From class BackPropMLPTest, method testMLPTrivial.

@Test
public void testMLPTrivial() {
    //Simplest possible case: 1 hidden layer, 1 hidden neuron, batch size of 1.
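    // getIrisMLPSimpleConfig(...) is a helper defined elsewhere in BackPropMLPTest
    MultiLayerNetwork network = new MultiLayerNetwork(getIrisMLPSimpleConfig(new int[] { 1 }, Activation.SIGMOID));
    network.init();
    // Log the score after every iteration
    network.setListeners(new ScoreIterationListener(1));
    // Iris data: minibatch size 1, 10 examples in total
    DataSetIterator iter = new IrisDataSetIterator(1, 10);
    while (iter.hasNext()) {
        network.fit(iter.next());
    }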
}
Also used : IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)
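ScoreIterationListener(1) reports the network's score (its current loss) on every iteration; a larger argument reports only every N-th iteration. The stand-alone sketch below mimics just that reporting rule in plain Java so it runs without DL4J. The ScoreLogger class, its iterationDone method and the message text are hypothetical illustrations, not DL4J's API.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a "report the score every N iterations" listener.
// This is NOT DL4J's ScoreIterationListener; it only mirrors the idea behind
// new ScoreIterationListener(1), which reports on every iteration.
public class ScoreLoggerSketch {

    static final class ScoreLogger {
        private final int printIterations;            // report every N iterations
        private final List<String> lines = new ArrayList<>();

        ScoreLogger(int printIterations) {
            if (printIterations <= 0) {
                throw new IllegalArgumentException("printIterations must be positive");
            }
            this.printIterations = printIterations;
        }

        // Called once per optimizer iteration with the current score (loss).
        void iterationDone(int iteration, double score) {
            if (iteration % printIterations == 0) {
                String line = "Score at iteration " + iteration + " is " + score; // assumed message format
                lines.add(line);
                System.out.println(line);
            }
        }

        List<String> lines() {
            return lines;
        }
    }

    public static void main(String[] args) {
        ScoreLogger every1 = new ScoreLogger(1);
        ScoreLogger every5 = new ScoreLogger(5);
        double score = 1.0;
        for (int i = 0; i < 10; i++) {
            every1.iterationDone(i, score);
            every5.iterationDone(i, score);
            score *= 0.9; // pretend the loss shrinks after each step
        }
        System.out.println(every1.lines().size() + " vs " + every5.lines().size());
    }
}
```

Over 10 iterations the N = 1 logger records 10 lines and the N = 5 logger records 2 (iterations 0 and 5). Reporting every iteration suits tiny tests such as testMLPTrivial but gets noisy in long training runs.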

Example 32 with ScoreIterationListener

Use of org.deeplearning4j.optimize.listeners.ScoreIterationListener in project deeplearning4j by deeplearning4j.

From class GravesLSTMOutputTest, method testSameLabelsOutputWithTBPTT.

@Test
public void testSameLabelsOutputWithTBPTT() {
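    // getNetworkConf, window, data, reshapeInput and eval are members of GravesLSTMOutputTest (not shown)
    MultiLayerNetwork network = new MultiLayerNetwork(getNetworkConf(40, true));
    network.init();
    network.setListeners(new ScoreIterationListener(1));
    // Train on consecutive 100-row slices of the data, using the same array as input and labels
    for (int i = 0; i < window / 100; i++) {
        INDArray d = data.get(NDArrayIndex.interval(100 * i, 100 * (i + 1)), NDArrayIndex.all());
        network.fit(reshapeInput(d.dup()), reshapeInput(d.dup()));
    }
    Evaluation ev = eval(network);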
}
Also used : Evaluation(org.deeplearning4j.eval.Evaluation) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) Test(org.junit.Test)
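The training loop above walks the data in consecutive, non-overlapping blocks of 100 rows, block i covering rows 100*i up to (but not including) 100*(i+1). The sketch below reproduces only that index arithmetic with plain Java arrays standing in for INDArray, assuming window is the number of rows to use; sliceRows is a hypothetical helper, not part of the test.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Sketch of the slicing in testSameLabelsOutputWithTBPTT using plain arrays.
// Assumption: "window" corresponds to data.length here.
public class WindowSliceSketch {

    static List<double[][]> sliceRows(double[][] data, int blockSize) {
        List<double[][]> blocks = new ArrayList<>();
        // Integer division: a trailing partial block is dropped, like window / 100.
        int numBlocks = data.length / blockSize;
        for (int i = 0; i < numBlocks; i++) {
            // copyOfRange's end index is exclusive: rows [blockSize*i, blockSize*(i+1)).
            // The copy is shallow (rows are shared), whereas the test calls d.dup() for a real copy.
            blocks.add(Arrays.copyOfRange(data, blockSize * i, blockSize * (i + 1)));
        }
        return blocks;
    }

    public static void main(String[] args) {
        double[][] data = new double[250][3];   // 250 rows -> 2 full blocks of 100
        for (int r = 0; r < data.length; r++) {
            data[r][0] = r;                     // tag each row with its index
        }
        List<double[][]> blocks = sliceRows(data, 100);
        System.out.println(blocks.size() + " blocks, second starts at row " + blocks.get(1)[0][0]);
    }
}
```

With 250 rows this yields two blocks; the last 50 rows are never visited, exactly as window / 100 truncates in the original loop.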

Example 33 with ScoreIterationListener

Use of org.deeplearning4j.optimize.listeners.ScoreIterationListener in project deeplearning4j by deeplearning4j.

From class BackTrackLineSearchTest, method testBackTrackLineLBFGS.

@Test
public void testBackTrackLineLBFGS() {
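    OptimizationAlgorithm optimizer = OptimizationAlgorithm.LBFGS;
    // irisIter and getIrisMultiLayerConfig are members of BackTrackLineSearchTest (not shown)
    DataSet data = irisIter.next();
    // Standardize the features to zero mean and unit variance
    data.normalizeZeroMeanZeroUnitVariance();
    MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, 5, optimizer));
    network.init();
    IterationListener listener = new ScoreIterationListener(1);
    network.setListeners(Collections.singletonList(listener));
    double oldScore = network.score(data);   // loss before training
    network.fit(data.getFeatureMatrix(), data.getLabels());
    double score = network.score();          // loss after fitting
    // Training with L-BFGS and the line search should have reduced the loss
    assertTrue(score < oldScore);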
}
Also used : OptimizationAlgorithm(org.deeplearning4j.nn.api.OptimizationAlgorithm) DataSet(org.nd4j.linalg.dataset.DataSet) IterationListener(org.deeplearning4j.optimize.api.IterationListener) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)
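BackTrackLineSearchTest exercises a backtracking line search: starting from a trial step, the step is shrunk until the loss drops enough. The sketch below is a generic textbook version with the Armijo sufficient-decrease condition on a one-dimensional function; it is not DL4J's BackTrackLineSearch implementation, and the parameter values (shrink 0.5, c = 1e-4) are illustrative assumptions.

```java
import java.util.function.DoubleUnaryOperator;

// Generic backtracking (Armijo) line search on a 1-D function.
// Textbook sketch only, not DL4J's BackTrackLineSearch.
public class BacktrackingSketch {

    /**
     * Shrinks the step until f(x + step*d) <= f(x) + c * step * f'(x) * d.
     * "direction" must be a descent direction, i.e. grad * direction < 0.
     * Returns 0.0 if no acceptable step is found within maxIter tries.
     */
    static double backtrack(DoubleUnaryOperator f, double x, double grad, double direction,
                            double initialStep, double shrink, double c, int maxIter) {
        double fx = f.applyAsDouble(x);
        double slope = grad * direction;     // directional derivative, negative for descent
        double step = initialStep;
        for (int i = 0; i < maxIter; i++) {
            if (f.applyAsDouble(x + step * direction) <= fx + c * step * slope) {
                return step;                 // sufficient decrease reached
            }
            step *= shrink;                  // shrink the step and try again
        }
        return 0.0;
    }

    public static void main(String[] args) {
        DoubleUnaryOperator f = v -> (v - 3) * (v - 3);  // minimum at v = 3
        double x = 0.0;
        double grad = 2 * (x - 3);                       // f'(0) = -6
        double step = backtrack(f, x, grad, -grad, 1.0, 0.5, 1e-4, 50);
        double xNew = x + step * (-grad);
        System.out.println("step=" + step + " f(old)=" + f.applyAsDouble(x)
                + " f(new)=" + f.applyAsDouble(xNew));
    }
}
```

Here the full step overshoots to v = 6 (no decrease), so it is halved once to 0.5, which lands on the minimum. The final check that f(new) < f(old) plays the same role as assertTrue(score < oldScore) in the test.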

Example 34 with ScoreIterationListener

Use of org.deeplearning4j.optimize.listeners.ScoreIterationListener in project deeplearning4j by deeplearning4j.

From class BackTrackLineSearchTest, method testBackTrackLineHessian.

@Test(expected = Exception.class)
public void testBackTrackLineHessian() {
    // The test passes only if an Exception is thrown: HESSIAN_FREE is expected to fail with this setup
    OptimizationAlgorithm optimizer = OptimizationAlgorithm.HESSIAN_FREE;
    DataSet data = irisIter.next();
    MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, 100, optimizer));
    network.init();
    IterationListener listener = new ScoreIterationListener(1);
    network.setListeners(Collections.singletonList(listener));
    network.fit(data.getFeatureMatrix(), data.getLabels());
}
Also used : OptimizationAlgorithm(org.deeplearning4j.nn.api.OptimizationAlgorithm) DataSet(org.nd4j.linalg.dataset.DataSet) IterationListener(org.deeplearning4j.optimize.api.IterationListener) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)

Example 35 with ScoreIterationListener

Use of org.deeplearning4j.optimize.listeners.ScoreIterationListener in project deeplearning4j by deeplearning4j.

From class TestPlayUI, method testUIMultipleSessions.

@Test
@Ignore
public void testUIMultipleSessions() throws Exception {
    for (int session = 0; session < 3; session++) {
        StatsStorage ss = new InMemoryStatsStorage();
        UIServer uiServer = UIServer.getInstance();
        uiServer.attach(ss);
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(1)
                .list()
                .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build())
                .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
                .pretrain(false).backprop(true)
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
        DataSetIterator iter = new IrisDataSetIterator(150, 150);
        for (int i = 0; i < 20; i++) {
            net.fit(iter);
            Thread.sleep(100);
        }
    }
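    // Manual (@Ignore'd) test: keep the JVM and UI server alive for 1,000 s (about 17 minutes) to inspect the sessions
    Thread.sleep(1000000);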
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) InMemoryStatsStorage(org.deeplearning4j.ui.storage.InMemoryStatsStorage) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) UIServer(org.deeplearning4j.ui.api.UIServer) StatsListener(org.deeplearning4j.ui.stats.StatsListener) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Ignore(org.junit.Ignore) Test(org.junit.Test)
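In testUIMultipleSessions two listeners are registered at once, a StatsListener feeding the UI storage and a ScoreIterationListener logging scores, and both receive every iteration. The plain-Java sketch below shows that fan-out pattern; the IterationCallback interface and Trainer class are hypothetical stand-ins, not DL4J's IterationListener or MultiLayerNetwork.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of notifying several listeners after each iteration.
// The types below are illustrative only; they are not DL4J's API.
public class ListenerFanOutSketch {

    interface IterationCallback {
        void iterationDone(int iteration, double score);
    }

    static final class Trainer {
        private final List<IterationCallback> listeners = new ArrayList<>();

        // In this sketch, setListeners replaces any previously registered listeners.
        void setListeners(IterationCallback... ls) {
            listeners.clear();
            listeners.addAll(Arrays.asList(ls));
        }

        void fit(int iterations) {
            double score = 1.0;
            for (int i = 0; i < iterations; i++) {
                score *= 0.8;                   // stand-in for one optimizer step
                for (IterationCallback l : listeners) {
                    l.iterationDone(i, score);  // every listener sees every iteration, in order
                }
            }
        }
    }

    public static void main(String[] args) {
        List<Double> statsStore = new ArrayList<>();   // plays the role of a stats storage
        Trainer t = new Trainer();
        t.setListeners(
                (it, s) -> statsStore.add(s),                                  // collect scores
                (it, s) -> System.out.println("iteration " + it + ": " + s));  // print scores
        t.fit(3);
        System.out.println("stored " + statsStore.size() + " scores");
    }
}
```

Keeping each concern in its own listener (storage for the UI, console logging) lets them be combined freely, which is what the varargs setListeners call in the test relies on.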

Aggregations

ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 76
Test (org.junit.Test): 75
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 44
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 43
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 41
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 39
DataSet (org.nd4j.linalg.dataset.DataSet): 37
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 35
InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver): 26
MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition): 26
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 23
MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition): 22
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 21
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 17
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 17
IterationListener (org.deeplearning4j.optimize.api.IterationListener): 15
MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator): 13
MaxScoreIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition): 13
IEarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer): 13
EarlyStoppingConfiguration (org.deeplearning4j.earlystopping.EarlyStoppingConfiguration): 12