Search in sources :

Example 31 with DataSet

use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

the class TestEarlyStoppingSpark method testTimeTermination.

@Test
public void testTimeTermination() {
    //test termination after max time
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).updater(Updater.SGD).learningRate(1e-6).weightInit(WeightInit.XAVIER).list().layer(0, new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build()).pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>().epochTerminationConditions(new MaxEpochsTerminationCondition(10000)).iterationTerminationConditions(new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS), //Initial score is ~2.5
    new MaxScoreIterationTerminationCondition(7.5)).scoreCalculator(new SparkDataSetLossCalculator(irisData, true, sc.sc())).modelSaver(saver).build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new SparkEarlyStoppingTrainer(getContext().sc(), new ParameterAveragingTrainingMaster(true, 4, 1, 150 / 15, 1, 0), esConf, net, irisData);
    long startTime = System.currentTimeMillis();
    EarlyStoppingResult result = trainer.fit();
    long endTime = System.currentTimeMillis();
    int durationSeconds = (int) (endTime - startTime) / 1000;
    assertTrue("durationSeconds = " + durationSeconds, durationSeconds >= 3);
    assertTrue("durationSeconds = " + durationSeconds, durationSeconds <= 9);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used : InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) DataSet(org.nd4j.linalg.dataset.DataSet) SparkEarlyStoppingTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingTrainer) SparkDataSetLossCalculator(org.deeplearning4j.spark.earlystopping.SparkDataSetLossCalculator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)

Example 32 with DataSet

use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

the class LoadSerializedDataSetFunction method call.

@Override
public DataSet call(PortableDataStream pds) throws Exception {
    try (InputStream is = pds.open()) {
        DataSet d = new DataSet();
        d.load(is);
        return d;
    }
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) InputStream(java.io.InputStream)

Example 33 with DataSet

use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

the class SparkDl4jMultiLayer method fitLabeledPoint.

/**
     * Fit a MultiLayerNetwork using Spark MLLib LabeledPoint instances.
     * This will convert the labeled points to the internal DL4J data format and train the model on that
     *
     * @param rdd the rdd to fitDataSet
     * @return the multi layer network that was fitDataSet
     */
public MultiLayerNetwork fitLabeledPoint(JavaRDD<LabeledPoint> rdd) {
    int nLayers = network.getLayerWiseConfigurations().getConfs().size();
    FeedForwardLayer ffl = (FeedForwardLayer) network.getLayerWiseConfigurations().getConf(nLayers - 1).getLayer();
    JavaRDD<DataSet> ds = MLLibUtil.fromLabeledPoint(sc, rdd, ffl.getNOut());
    return fit(ds);
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) LabeledPoint(org.apache.spark.mllib.regression.LabeledPoint)

Example 34 with DataSet

use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

the class IEvaluateFlatMapFunctionAdapter method call.

@Override
public Iterable<T> call(Iterator<DataSet> dataSetIterator) throws Exception {
    if (!dataSetIterator.hasNext()) {
        return Collections.emptyList();
    }
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);
    List<DataSet> collect = new ArrayList<>();
    int totalCount = 0;
    while (dataSetIterator.hasNext()) {
        collect.clear();
        int nExamples = 0;
        while (dataSetIterator.hasNext() && nExamples < evalBatchSize) {
            DataSet next = dataSetIterator.next();
            nExamples += next.numExamples();
            collect.add(next);
        }
        totalCount += nExamples;
        DataSet data = DataSet.merge(collect);
        INDArray out;
        if (data.hasMaskArrays()) {
            out = network.output(data.getFeatureMatrix(), false, data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        } else {
            out = network.output(data.getFeatureMatrix(), false);
        }
        if (data.getLabels().rank() == 3) {
            if (data.getLabelsMaskArray() == null) {
                evaluation.evalTimeSeries(data.getLabels(), out);
            } else {
                evaluation.evalTimeSeries(data.getLabels(), out, data.getLabelsMaskArray());
            }
        } else {
            evaluation.eval(data.getLabels(), out);
        }
    }
    if (log.isDebugEnabled()) {
        log.debug("Evaluated {} examples ", totalCount);
    }
    return Collections.singletonList(evaluation);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Example 35 with DataSet

use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

the class ScoreExamplesWithKeyFunctionAdapter method call.

@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, DataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);
    List<Tuple2<K, Double>> ret = new ArrayList<>();
    List<DataSet> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, DataSet> t2 = iterator.next();
            DataSet ds = t2._2();
            int n = ds.numExamples();
            if (n != 1)
                throw new IllegalStateException("Cannot score examples with one key per data set if " + "data set contains more than 1 example (numExamples: " + n + ")");
            collect.add(ds);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;
        DataSet data = DataSet.merge(collect);
        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();
        for (int i = 0; i < doubleScores.length; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), doubleScores[i]));
        }
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret;
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Aggregations

DataSet (org.nd4j.linalg.dataset.DataSet)334 Test (org.junit.Test)226 INDArray (org.nd4j.linalg.api.ndarray.INDArray)194 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)93 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)82 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)79 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)73 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)62 ArrayList (java.util.ArrayList)50 MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator)41 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)38 BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)34 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)32 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)31 MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)31 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)25 SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader)24 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)24 CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader)23 ClassPathResource (org.nd4j.linalg.io.ClassPathResource)23