Search in sources :

Example 21 with DataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

The following is the class EvaluationToolsTests, method testRocHtml:

@Test
public void testRocHtml() throws Exception {
    // Load the full Iris dataset (150 examples) as a single batch.
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);

    // Small 4-4-2 network: tanh hidden layer, softmax + MCXENT output (two classes).
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4)
                                    .activation(Activation.TANH).build())
                    .layer(1, new OutputLayer.Builder().nIn(4).nOut(2)
                                    .activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork network = new MultiLayerNetwork(config);
    network.init();

    // Standardize the features in place.
    NormalizerStandardize normalizer = new NormalizerStandardize();
    DataSet data = irisIter.next();
    normalizer.fit(data);
    normalizer.transform(data);

    // Collapse the three Iris classes into two: classes {0,1} -> column 0, class {2} -> column 1.
    INDArray binaryLabels = Nd4j.create(150, 2);
    binaryLabels.getColumn(0).assign(data.getLabels().getColumn(0));
    binaryLabels.getColumn(0).addi(data.getLabels().getColumn(1));
    binaryLabels.getColumn(1).assign(data.getLabels().getColumn(2));
    data.setLabels(binaryLabels);

    for (int epoch = 0; epoch < 30; epoch++) {
        network.fit(data);
    }

    // Evaluate ROC with 20 threshold steps and render the HTML chart.
    ROC roc = new ROC(20);
    irisIter.reset();
    INDArray features = data.getFeatures();
    INDArray labels = data.getLabels();
    INDArray predictions = network.output(features);
    roc.eval(labels, predictions);
    String html = EvaluationTools.rocChartToHtml(roc);
    // System.out.println(html);
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.api.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) NormalizerStandardize(org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)

Example 22 with DataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

The following is the class RecordReaderMultiDataSetIteratorTest, method testImagesRRDMSI:

/**
 * Checks that RecordReaderMultiDataSetIterator produces the same features/labels as two
 * plain RecordReaderDataSetIterators reading the same images (minibatch size 1).
 *
 * Fix: removed the unused local {@code Random r = new Random(12345)} — it was never read.
 */
@Test
public void testImagesRRDMSI() throws Exception {
    // Build a throwaway two-class LFW-style directory tree with one image per class.
    File parentDir = Files.createTempDir();
    parentDir.deleteOnExit();
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
    File f1 = new File(str1);
    File f2 = new File(str2);
    f1.mkdirs();
    f2.mkdirs();
    writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
    int outputNum = 2;
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    // Two image readers at different resolutions feed two inputs of one MultiDataSetIterator;
    // labels come one-hot from the second reader.
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build();
    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));
    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);
    // Minibatch size 1 with two example images -> exactly two iterations.
    for (int i = 0; i < 2; i++) {
        MultiDataSet mds = trainDataIterator.next();
        DataSet d1 = dsi1.next();
        DataSet d2 = dsi2.next();
        assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
        assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
        assertEquals(d1.getLabels(), mds.getLabels(0));
    }
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ParentPathLabelGenerator(org.datavec.api.io.labels.ParentPathLabelGenerator) Random(java.util.Random) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) Test(org.junit.Test)

Example 23 with DataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

The following is the class RecordReaderMultiDataSetIteratorTest, method testImagesRRDMSI_Batched:

/**
 * Same comparison as testImagesRRDMSI, but with minibatch size 2 so both example
 * images arrive in a single MultiDataSet / DataSet.
 */
@Test
public void testImagesRRDMSI_Batched() throws Exception {
    // Build a throwaway two-class image directory (one image per class).
    File rootDir = Files.createTempDir();
    rootDir.deleteOnExit();
    String zicoPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "Zico/");
    String ziwangPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "Ziwang_Xu/");
    File zicoDir = new File(zicoPath);
    File ziwangDir = new File(ziwangPath);
    zicoDir.mkdirs();
    ziwangDir.mkdirs();
    writeStreamToFile(new File(FilenameUtils.concat(zicoDir.getPath(), "Zico_0001.jpg")),
                    new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    writeStreamToFile(new File(FilenameUtils.concat(ziwangDir.getPath(), "Ziwang_Xu_0001.jpg")),
                    new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());

    int numClasses = 2;
    ParentPathLabelGenerator labelGen = new ParentPathLabelGenerator();

    // Two image readers at different resolutions -> two inputs of one MultiDataSetIterator.
    ImageRecordReader readerLarge = new ImageRecordReader(10, 10, 1, labelGen);
    ImageRecordReader readerSmall = new ImageRecordReader(5, 5, 1, labelGen);
    readerLarge.initialize(new FileSplit(rootDir));
    readerSmall.initialize(new FileSplit(rootDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(2)
                    .addReader("rr1", readerLarge)
                    .addReader("rr1s", readerSmall)
                    .addInput("rr1", 0, 0)
                    .addInput("rr1s", 0, 0)
                    .addOutputOneHot("rr1s", 1, numClasses)
                    .build();

    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    ImageRecordReader refLarge = new ImageRecordReader(10, 10, 1, labelGen);
    ImageRecordReader refSmall = new ImageRecordReader(5, 5, 1, labelGen);
    refLarge.initialize(new FileSplit(rootDir));
    refSmall.initialize(new FileSplit(rootDir));
    DataSetIterator refIterLarge = new RecordReaderDataSetIterator(refLarge, 2, 1, 2);
    DataSetIterator refIterSmall = new RecordReaderDataSetIterator(refSmall, 2, 1, 2);

    // Batch of 2 covers both examples at once; both paths must agree exactly.
    MultiDataSet mds = trainDataIterator.next();
    DataSet d1 = refIterLarge.next();
    DataSet d2 = refIterSmall.next();
    assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
    assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
    assertEquals(d1.getLabels(), mds.getLabels(0));
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ParentPathLabelGenerator(org.datavec.api.io.labels.ParentPathLabelGenerator) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) Test(org.junit.Test)

Example 24 with DataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

The following is the class ScoreFlatMapFunctionCGDataSetAdapter, method call:

/**
 * Scores every DataSet in this Spark partition against the broadcast network parameters.
 * Returns (exampleCount, score * exampleCount) pairs so the caller can compute a
 * properly weighted average score across partitions.
 */
@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<DataSet> dataSetIterator) throws Exception {
    // Empty partition: contribute zero examples with zero score.
    if (!dataSetIterator.hasNext()) {
        return Collections.singletonList(new Tuple2<>(0, 0.0));
    }
    //Does batching where appropriate
    DataSetIterator batchedIter = new IteratorDataSetIterator(dataSetIterator, minibatchSize);

    // Rebuild the computation graph locally from the broadcast JSON configuration.
    ComputationGraph graph = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
    graph.init();

    //.value() is shared by all executors on single machine -> OK, as params are not changed in score function
    INDArray paramsCopy = params.value().unsafeDuplication();
    if (paramsCopy.length() != graph.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    graph.setParams(paramsCopy);

    // Score each minibatch, weighting by the number of examples it contains.
    List<Tuple2<Integer, Double>> results = new ArrayList<>();
    while (batchedIter.hasNext()) {
        DataSet batch = batchedIter.next();
        double batchScore = graph.score(batch, false);
        int exampleCount = batch.getFeatureMatrix().size(0);
        results.add(new Tuple2<>(exampleCount, batchScore * exampleCount));
    }

    // Flush any queued ops before this executor thread returns.
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    return results;
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) IteratorDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorDataSetIterator) IteratorDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator)

Example 25 with DataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

The following is the class ParallelWrapper, method fit:

/**
     * This method takes DataSetIterator, and starts training over it by scheduling DataSets to different executors.
     * Workers are created lazily on first call; DataSets are fed round-robin to workers, and model
     * parameters/updaters are averaged every {@code averagingFrequency} iterations.
     *
     * @param source iterator supplying training DataSets; it is reset before training begins
     */
public synchronized void fit(@NonNull DataSetIterator source) {
    stopFit.set(false);
    // Lazily create and start one Trainer thread per worker on the first fit() call.
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread());
            // When MagicQueue is in use, pin each worker thread to a device round-robin.
            if (isMQ)
                Nd4j.getAffinityManager().attachThreadToDevice(zoo[cnt], cnt % Nd4j.getAffinityManager().getNumberOfDevices());
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    }
    source.reset();
    // Optionally wrap the source in an async prefetching iterator (device-bucketed when isMQ).
    DataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        if (isMQ) {
            if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0)
                log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers, Nd4j.getAffinityManager().getNumberOfDevices());
            MagicQueue queue = new MagicQueue.Builder().setCapacityPerFlow(8).setMode(MagicQueue.Mode.SEQUENTIAL).setNumberOfBuckets(Nd4j.getAffinityManager().getNumberOfDevices()).build();
            iterator = new AsyncDataSetIterator(source, prefetchSize, queue);
        } else
            iterator = new AsyncDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    // locker counts how many workers have been fed in the current dispatch round.
    AtomicInteger locker = new AtomicInteger(0);
    int whiles = 0;
    while (iterator.hasNext() && !stopFit.get()) {
        whiles++;
        DataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as DataSet");
        /*
             now dataSet should be dispatched to next free workers, until all workers are busy. And then we should block till all finished.
            */
        int pos = locker.getAndIncrement();
        if (zoo == null)
            throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
        zoo[pos].feedDataSet(dataSet);
        /*
                if all workers are dispatched now, join till all are finished
            */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            // Block until every worker fed in this round has finished its fit step.
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
                    average model, and propagate it to whole
                */
            // Averaging runs only on full rounds (pos + 1 == workers), every averagingFrequency iterations.
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof MultiLayerNetwork) {
                    if (averageUpdaters) {
                        Updater updater = ((MultiLayerNetwork) model).getUpdater();
                        // NOTE(review): batchSize is accumulated in both branches below but never
                        // read afterwards — looks like dead bookkeeping; confirm before removing.
                        int batchSize = 0;
                        if (updater != null && updater.getStateViewArray() != null) {
                            if (!legacyAveraging || Nd4j.getAffinityManager().getNumberOfDevices() == 1) {
                                // Modern path: average all worker updater state arrays in one op.
                                List<INDArray> updaters = new ArrayList<>();
                                for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    updaters.add(workerModel.getUpdater().getStateViewArray());
                                    batchSize += workerModel.batchSize();
                                }
                                Nd4j.averageAndPropagate(updater.getStateViewArray(), updaters);
                            } else {
                                // Legacy multi-device path: sum worker states on host, divide by count.
                                INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
                                int cnt = 0;
                                for (; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    state.addi(workerModel.getUpdater().getStateViewArray().dup());
                                    batchSize += workerModel.batchSize();
                                }
                                state.divi(cnt);
                                updater.setStateViewArray((MultiLayerNetwork) model, state, false);
                            }
                        }
                    }
                    ((MultiLayerNetwork) model).setScore(score);
                } else if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                }
                // Legacy averaging requires pushing the averaged model back to every worker.
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            locker.set(0);
        }
    }
    // sanity checks, or the dataset may never average
    // NOTE(review): wasAveraged is not assigned anywhere in this method — presumably set by
    // getScore()/averageUpdatersState(); verify in the rest of the class.
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    //            throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
//        iterationsCounter.set(0);
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) DataSet(org.nd4j.linalg.dataset.api.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)

Aggregations

DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)147 Test (org.junit.Test)133 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)90 DataSet (org.nd4j.linalg.dataset.DataSet)79 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)70 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)63 INDArray (org.nd4j.linalg.api.ndarray.INDArray)61 MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator)53 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)49 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)43 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)30 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)24 MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)21 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)19 InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver)17 MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition)17 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)17 ListDataSetIterator (org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator)16 BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)16 MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition)14