Search in sources:

Example 11 with MultiDataSet

use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The method call of the class ScoreFlatMapFunctionCGMultiDataSetAdapter.

/**
 * Scores the minibatches of one partition with a freshly built {@link ComputationGraph}.
 * <p>
 * The graph is created from the {@code json} configuration, initialised, and given a copy of
 * the broadcast parameters. Every minibatch then contributes one tuple of
 * (number of examples, score multiplied by number of examples).
 *
 * @param dataSetIterator the MultiDataSet objects to score
 * @return a single {@code (0, 0.0)} tuple when the iterator has no elements, otherwise one
 *         tuple per minibatch
 * @throws IllegalStateException if the broadcast parameter count differs from the network's
 * @throws Exception declared by the enclosing function interface
 */
@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    // Nothing to score: report zero examples with a zero score
    if (!dataSetIterator.hasNext()) {
        return Collections.singletonList(new Tuple2<>(0, 0.0));
    }

    // Does batching where appropriate
    MultiDataSetIterator batches = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize);

    ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(json);
    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    // .value() is shared by all executors on a single machine -> OK, as params are not changed in score function
    INDArray broadcastParams = params.value().unsafeDuplication();
    if (broadcastParams.length() != graph.numParams(false)) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    }
    graph.setParams(broadcastParams);

    List<Tuple2<Integer, Double>> scores = new ArrayList<>();
    while (batches.hasNext()) {
        MultiDataSet batch = batches.next();
        double batchScore = graph.score(batch, false);
        int exampleCount = batch.getFeatures(0).size(0);
        // Weight the score by the example count so callers can combine batches of unequal size
        Tuple2<Integer, Double> entry = new Tuple2<>(exampleCount, batchScore * exampleCount);
        scores.add(entry);
    }

    // Flush any queued operations before handing the results back
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
    return scores;
}
Also used : IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) ArrayList(java.util.ArrayList) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) Tuple2(scala.Tuple2) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Example 12 with MultiDataSet

use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The method call of the class ExecuteWorkerMultiDataSetFlatMapAdapter.

/**
 * Runs the configured worker over the MultiDataSet objects of one partition.
 * <p>
 * Flow, as visible in this method:
 * <ol>
 *   <li>Return an empty list immediately if the iterator has no elements.</li>
 *   <li>Wrap the iterator in an {@link IteratorMultiDataSetIterator} using the per-worker batch
 *       size, and additionally in an {@link AsyncMultiDataSetIterator} when the configured
 *       prefetch count is greater than zero.</li>
 *   <li>Obtain the graph from {@code worker.getInitialModelGraph()} and pass each minibatch to
 *       {@code worker.processMinibatch(...)} (or the {@code WithStats} variant), up to the
 *       configured maximum number of minibatches.</li>
 *   <li>Return as soon as the worker yields a non-null result; otherwise return the worker's
 *       final result.</li>
 * </ol>
 * When {@code isCollectTrainingStats()} is true, a {@link StatsCalculationHelper} is notified
 * around each step and the stats it builds are attached to the returned result.
 *
 * @param dataSetIterator the MultiDataSet objects for this partition
 * @return an empty list when there is no data, otherwise a singleton list holding the worker result
 * @throws Exception declared by the enclosing function interface
 */
@Override
public Iterable<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    WorkerConfiguration dataConfig = worker.getDataConfiguration();
    boolean stats = dataConfig.isCollectTrainingStats();
    // The helper only exists when stats collection is on; every use below is guarded by 'stats'
    StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
    if (stats)
        s.logMethodStartTime();
    if (!dataSetIterator.hasNext()) {
        if (stats)
            s.logReturnTime();
        //Sometimes: no data
        return Collections.emptyList();
    }
    int batchSize = dataConfig.getBatchSizePerWorker();
    final int prefetchCount = dataConfig.getPrefetchNumBatches();
    MultiDataSetIterator batchedIterator = new IteratorMultiDataSetIterator(dataSetIterator, batchSize);
    // Optional prefetching: only wrap in the async iterator when a positive prefetch count is configured
    if (prefetchCount > 0) {
        batchedIterator = new AsyncMultiDataSetIterator(batchedIterator, prefetchCount);
    }
    try {
        if (stats)
            s.logInitialModelBefore();
        ComputationGraph net = worker.getInitialModelGraph();
        if (stats)
            s.logInitialModelAfter();
        int miniBatchCount = 0;
        // A non-positive configured maximum means "no limit"
        int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
        while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
            if (stats)
                s.logNextDataSetBefore();
            MultiDataSet next = batchedIterator.next();
            if (stats)
                s.logNextDataSetAfter(next.getFeatures(0).size(0));
            if (stats) {
                s.logProcessMinibatchBefore();
                // Third argument is true when the iterator has no further minibatches
                Pair<R, SparkTrainingStats> result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                s.logProcessMinibatchAfter();
                if (result != null) {
                    //Terminate training immediately
                    s.logReturnTime();
                    // Combine the worker's stats with those recorded by the helper and attach them to the result
                    SparkTrainingStats workerStats = result.getSecond();
                    SparkTrainingStats returnStats = s.build(workerStats);
                    result.getFirst().setStats(returnStats);
                    return Collections.singletonList(result.getFirst());
                }
            } else {
                // Third argument is true when the iterator has no further minibatches
                R result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
                if (result != null) {
                    //Terminate training immediately
                    return Collections.singletonList(result);
                }
            }
        }
        //For some reason, we didn't return already. Normally this shouldn't happen
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair = worker.getFinalResultWithStats(net);
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            return Collections.singletonList(worker.getFinalResult(net));
        }
    } finally {
        // Runs on every exit path from the try block, including the early returns above
        //Make sure we shut down the async thread properly...
        if (batchedIterator instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) batchedIterator).shutdown();
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
}
Also used : IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) WorkerConfiguration(org.deeplearning4j.spark.api.WorkerConfiguration) StatsCalculationHelper(org.deeplearning4j.spark.api.stats.StatsCalculationHelper) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Example 13 with MultiDataSet

use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The method processList of the class BatchAndExportMultiDataSetsFunction.

/**
 * Merges and exports at most one minibatch taken from the head of {@code tempList}.
 * <p>
 * Nothing is exported (and {@code tempList} is left untouched) when the list is empty, or when
 * it holds fewer than {@code minibatchSize} examples and this is not the final export.
 * Otherwise MultiDataSet objects are removed from the head of the list until exactly
 * {@code minibatchSize} examples have been collected or the list is exhausted. An object that
 * would overshoot the minibatch size is split into its individual examples, which are put back
 * at the head of the list <em>in their original order</em>, so examples are consumed in the
 * order they arrived.
 *
 * @param tempList     pending MultiDataSet objects; consumed objects are removed from it
 * @param partitionIdx partition index, passed through to {@code export}
 * @param countBefore  export counter value before this call, passed to {@code export}
 * @param finalExport  if true, export even when fewer than {@code minibatchSize} examples remain
 * @return pair of (export counter after this call, list of strings returned by {@code export},
 *         presumably the export paths); the counter is unchanged and the list is empty when
 *         nothing was exported
 * @throws Exception if merging or exporting fails
 */
private Pair<Integer, List<String>> processList(LinkedList<MultiDataSet> tempList, int partitionIdx, int countBefore, boolean finalExport) throws Exception {
    //Go through the list. If we have enough examples: remove the MultiDataSet objects, merge and export them. Otherwise: do nothing
    int numExamples = 0;
    for (MultiDataSet ds : tempList) {
        numExamples += ds.getFeatures(0).size(0);
    }
    if (tempList.isEmpty() || (numExamples < minibatchSize && !finalExport)) {
        //No op
        return new Pair<>(countBefore, Collections.<String>emptyList());
    }

    //Batch the required number together
    int countSoFar = 0;
    List<MultiDataSet> tempToMerge = new ArrayList<>();
    while (!tempList.isEmpty() && countSoFar != minibatchSize) {
        MultiDataSet next = tempList.removeFirst();
        int nextSize = next.getFeatures(0).size(0);
        if (countSoFar + nextSize <= minibatchSize) {
            //Add the entire MultiDataSet object
            tempToMerge.add(next);
            countSoFar += nextSize;
        } else {
            //Split the MultiDataSet into single examples and put them back at the head of the
            //queue. Inserting the whole list at index 0 keeps the examples in their original
            //order; calling addFirst(...) once per example would reverse them.
            tempList.addAll(0, next.asList());
        }
    }

    //At this point: we should have the required number of examples in tempToMerge (unless it's a final export)
    MultiDataSet toExport = org.nd4j.linalg.dataset.MultiDataSet.merge(tempToMerge);
    List<String> exportPaths = new ArrayList<>();
    exportPaths.add(export(toExport, partitionIdx, countBefore));
    return new Pair<>(countBefore + 1, exportPaths);
}
Also used : MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) Pair(org.deeplearning4j.berkeley.Pair)

Example 14 with MultiDataSet

use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The method testBasic of the class TestSparkComputationGraph.

/**
 * Smoke test: fits a small two-layer ComputationGraph on Spark, first from an RDD of
 * MultiDataSet objects read from {@code iris.txt}, then from an RDD of DataSet objects produced
 * by {@link IrisDataSetIterator}. The test passes if neither fit call throws.
 */
@Test
public void testBasic() throws Exception {
    JavaSparkContext context = this.sc;

    // Load iris.txt one example at a time as MultiDataSet objects
    RecordReader csvReader = new CSVRecordReader(0, ",");
    csvReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator mdsIterator = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addReader("iris", csvReader)
                    .addInput("iris", 0, 3)
                    .addOutputOneHot("iris", 4, 3)
                    .build();
    List<MultiDataSet> multiDataSets = new ArrayList<>(150);
    while (mdsIterator.hasNext()) {
        multiDataSets.add(mdsIterator.next());
    }

    // in(4) -> dense(2) -> out(3)
    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .learningRate(0.1)
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
                    .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense")
                    .setOutputs("out")
                    .pretrain(false)
                    .backprop(true)
                    .build();
    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    SparkComputationGraph sparkGraph = new SparkComputationGraph(context, graph, trainingMaster);
    IterationListener scoreListener = new ScoreIterationListener(1);
    sparkGraph.setListeners(Collections.singleton(scoreListener));

    // Fit using MultiDataSet
    JavaRDD<MultiDataSet> multiDataSetRdd = context.parallelize(multiDataSets);
    sparkGraph.fitMultiDataSet(multiDataSetRdd);

    //Try: fitting using DataSet
    DataSetIterator irisIterator = new IrisDataSetIterator(1, 150);
    List<DataSet> dataSets = new ArrayList<>();
    while (irisIterator.hasNext()) {
        dataSets.add(irisIterator.next());
    }
    JavaRDD<DataSet> dataSetRdd = context.parallelize(dataSets);
    sparkGraph.fit(dataSetRdd);
}
Also used : IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) FileSplit(org.datavec.api.split.FileSplit) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) RecordReaderMultiDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) IterationListener(org.deeplearning4j.optimize.api.IterationListener) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) RecordReaderMultiDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Example 15 with MultiDataSet

use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The method testSplittingCSVSequence of the class RecordReaderMultiDataSetIteratorTest.

/**
 * Checks that a {@link RecordReaderMultiDataSetIterator} which splits the feature sequence into
 * two inputs (columns 0-1 and column 2) yields the same data as a
 * {@link SequenceRecordReaderDataSetIterator} reading the same files: the two feature arrays
 * must equal the matching column slices of the DataSet features, and the single label array
 * must equal the DataSet labels. No mask arrays are expected.
 */
@Test
public void testSplittingCSVSequence() throws Exception {
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    // Build "%d" path patterns from the index-0 files. Only the last '0' of the path (the file
    // index) is replaced, so zeros in parent directory names cannot corrupt the pattern.
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = toNumberedPathPattern(resource.getTempFileFromArchive().getAbsolutePath());
    resource = new ClassPathResource("csvsequencelabels_0.txt");
    String labelsPath = toNumberedPathPattern(resource.getTempFileFromArchive().getAbsolutePath());

    // Reference iterator: plain features/labels DataSet
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    // Iterator under test: same files, features split into two inputs
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2).addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();

    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = srrmdsi.next();
        assertEquals(2, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();
        assertNotNull(fmds);
        assertNotNull(lmds);
        for (int i = 0; i < fmds.length; i++) {
            assertNotNull(fmds[i]);
        }
        for (int i = 0; i < lmds.length; i++) {
            assertNotNull(lmds[i]);
        }
        // Expected inputs are the column slices [0,1] and [2,2] of the reference features
        INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 1, true), NDArrayIndex.all());
        INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 2, true), NDArrayIndex.all());
        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(lds, lmds[0]);
    }
    // Both iterators must be exhausted at the same time
    assertFalse(srrmdsi.hasNext());
}

/**
 * Turns the path of an index-0 file into a {@code %d} pattern by replacing only the last
 * {@code '0'} in the path.
 * <p>
 * Replacing every {@code '0'} (as {@code replaceAll("0", "%d")} does) also rewrites zeros in
 * parent directory names, producing a pattern with more than one {@code %d}. This assumes the
 * last {@code '0'} in the path is the file index, e.g. {@code .../csvsequence_0.txt}.
 *
 * @param path absolute path of the index-0 file
 * @return the path with its last {@code '0'} replaced by {@code "%d"}, or {@code path}
 *         unchanged if it contains no {@code '0'}
 */
private static String toNumberedPathPattern(String path) {
    int lastZero = path.lastIndexOf('0');
    if (lastZero < 0) {
        return path;
    }
    return path.substring(0, lastZero) + "%d" + path.substring(lastZero + 1);
}
Also used : CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) NumberedFileInputSplit(org.datavec.api.split.NumberedFileInputSplit) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) Test(org.junit.Test)

Aggregations

MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 28, Test (org.junit.Test): 12, ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 11, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 10, MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 10, SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 8, CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 8, DataSet (org.nd4j.linalg.dataset.DataSet): 8, FileSplit (org.datavec.api.split.FileSplit): 7, ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader): 6, ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 6, RecordReader (org.datavec.api.records.reader.RecordReader): 5, CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 5, NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit): 5, ArrayList (java.util.ArrayList): 4, RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 4, GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 4, AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator): 3, DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 3, Random (java.util.Random): 2