Search in sources:

Example 1 with DataSet

Use of org.nd4j.linalg.dataset.api.DataSet in the project deeplearning4j, by deeplearning4j.

From the class TestCnnSentenceDataSetIterator, the method testSentenceIterator.

@Test
public void testSentenceIterator() throws Exception {
    // Load a small pre-trained word2vec model from test resources.
    WordVectors wordVectors = WordVectorSerializer.readWord2VecModel(
                    new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());
    int vecSize = wordVectors.lookupTable().layerSize();
    List<String> sentences = new ArrayList<>();
    //First word: all present
    sentences.add("these balance Database model");
    sentences.add("into same THISWORDDOESNTEXIST are");
    // Longest sentence (in known tokens) has 4 words; the second sentence has only 3 known
    // tokens, since THISWORDDOESNTEXIST is absent from the vocabulary.
    int maxLength = 4;
    List<List<String>> knownTokens = Arrays.asList(
                    Arrays.asList("these", "balance", "Database", "model"),
                    Arrays.asList("into", "same", "are"));
    List<String> labelsForSentences = Arrays.asList("Positive", "Negative");
    //Order of labels: alphabetic. Positive -> [0,1]
    INDArray expLabels = Nd4j.create(new double[][] { { 0, 1 }, { 1, 0 } });
    // Exercise both layouts: word vectors along the height axis and along the width axis.
    for (boolean alongHeight : new boolean[] { true, false }) {
        INDArray expectedFeatures = alongHeight ? Nd4j.create(2, 1, maxLength, vecSize)
                        : Nd4j.create(2, 1, vecSize, maxLength);
        // Second example has one padded (masked-out) time step.
        INDArray expectedFeatureMask = Nd4j.create(new double[][] { { 1, 1, 1, 1 }, { 1, 1, 1, 0 } });
        // Fill the expected features: one word vector per known token, per example.
        for (int ex = 0; ex < knownTokens.size(); ex++) {
            List<String> tokens = knownTokens.get(ex);
            for (int t = 0; t < tokens.size(); t++) {
                INDArray destination = alongHeight
                                ? expectedFeatures.get(NDArrayIndex.point(ex), NDArrayIndex.point(0),
                                                NDArrayIndex.point(t), NDArrayIndex.all())
                                : expectedFeatures.get(NDArrayIndex.point(ex), NDArrayIndex.point(0),
                                                NDArrayIndex.all(), NDArrayIndex.point(t));
                destination.assign(wordVectors.getWordVectorMatrix(tokens.get(t)));
            }
        }
        LabeledSentenceProvider provider = new CollectionLabeledSentenceProvider(sentences, labelsForSentences, null);
        CnnSentenceDataSetIterator iter = new CnnSentenceDataSetIterator.Builder().sentenceProvider(provider)
                        .wordVectors(wordVectors).maxSentenceLength(256).minibatchSize(32)
                        .sentencesAlongHeight(alongHeight).build();
        DataSet ds = iter.next();
        assertArrayEquals(expectedFeatures.shape(), ds.getFeatures().shape());
        assertEquals(expectedFeatures, ds.getFeatures());
        assertEquals(expLabels, ds.getLabels());
        assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
        assertNull(ds.getLabelsMaskArray());
        // loadSingleSentence must produce the same features as the corresponding batch slice
        // (with padding trimmed for the shorter sentence).
        INDArray single1 = iter.loadSingleSentence(sentences.get(0));
        INDArray single2 = iter.loadSingleSentence(sentences.get(1));
        INDArray sub1 = ds.getFeatures().get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.all(),
                        NDArrayIndex.all(), NDArrayIndex.all());
        INDArray sub2 = alongHeight
                        ? ds.getFeatures().get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all(),
                                        NDArrayIndex.interval(0, 3), NDArrayIndex.all())
                        : ds.getFeatures().get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all(),
                                        NDArrayIndex.all(), NDArrayIndex.interval(0, 3));
        assertArrayEquals(sub1.shape(), single1.shape());
        assertArrayEquals(sub2.shape(), single2.shape());
        assertEquals(sub1, single1);
        assertEquals(sub2, single2);
    }
}
Also used : DataSet(org.nd4j.linalg.dataset.api.DataSet) ArrayList(java.util.ArrayList) CollectionLabeledSentenceProvider(org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider) ClassPathResource(org.datavec.api.util.ClassPathResource) INDArray(org.nd4j.linalg.api.ndarray.INDArray) CollectionLabeledSentenceProvider(org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) Test(org.junit.Test)

Example 2 with DataSet

Use of org.nd4j.linalg.dataset.api.DataSet in the project deeplearning4j, by deeplearning4j.

From the class ParallelWrapper, the method fit.

/**
     * This method takes DataSetIterator, and starts training over it by scheduling DataSets to different executors
     *
     * @param source iterator supplying training DataSets; it is reset before training begins
     */
public synchronized void fit(@NonNull DataSetIterator source) {
    // Clear any stop request left over from a previous run so this fit() proceeds.
    stopFit.set(false);
    // Lazily create and start the worker threads on the first fit() invocation.
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread());
            // If MagicQueue routing is enabled, pin each worker thread to a device round-robin,
            // so per-device data buckets line up with worker threads.
            if (isMQ)
                Nd4j.getAffinityManager().attachThreadToDevice(zoo[cnt], cnt % Nd4j.getAffinityManager().getNumberOfDevices());
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    }
    source.reset();
    // Optionally wrap the source in an async prefetching iterator.
    DataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        if (isMQ) {
            if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0)
                log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers, Nd4j.getAffinityManager().getNumberOfDevices());
            MagicQueue queue = new MagicQueue.Builder().setCapacityPerFlow(8).setMode(MagicQueue.Mode.SEQUENTIAL).setNumberOfBuckets(Nd4j.getAffinityManager().getNumberOfDevices()).build();
            iterator = new AsyncDataSetIterator(source, prefetchSize, queue);
        } else
            iterator = new AsyncDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    // locker counts how many workers have been handed a DataSet in the current dispatch round;
    // it is reset to 0 once the round completes.
    AtomicInteger locker = new AtomicInteger(0);
    int whiles = 0;
    while (iterator.hasNext() && !stopFit.get()) {
        whiles++;
        DataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as DataSet");
        /*
             now dataSet should be dispatched to next free workers, until all workers are busy. And then we should block till all finished.
            */
        int pos = locker.getAndIncrement();
        // Guard against a concurrent shutdown() nulling out the worker array mid-run.
        if (zoo == null)
            throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
        zoo[pos].feedDataSet(dataSet);
        /*
                if all workers are dispatched now, join till all are finished
            */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            // Block until every worker dispatched in this round has finished its DataSet.
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
                    average model, and propagate it to whole
                */
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                // NOTE(review): getScore(locker) presumably also averages parameters across
                // workers as a side effect — confirm against the helper's implementation.
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof MultiLayerNetwork) {
                    if (averageUpdaters) {
                        Updater updater = ((MultiLayerNetwork) model).getUpdater();
                        int batchSize = 0;
                        if (updater != null && updater.getStateViewArray() != null) {
                            if (!legacyAveraging || Nd4j.getAffinityManager().getNumberOfDevices() == 1) {
                                // Non-legacy (or single-device) path: average all worker updater
                                // state views in one call and propagate the result back.
                                List<INDArray> updaters = new ArrayList<>();
                                for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    updaters.add(workerModel.getUpdater().getStateViewArray());
                                    batchSize += workerModel.batchSize();
                                }
                                Nd4j.averageAndPropagate(updater.getStateViewArray(), updaters);
                            } else {
                                // Legacy multi-device path: sum duplicated state arrays, then
                                // divide by the number of contributing workers.
                                INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
                                int cnt = 0;
                                for (; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    state.addi(workerModel.getUpdater().getStateViewArray().dup());
                                    batchSize += workerModel.batchSize();
                                }
                                state.divi(cnt);
                                updater.setStateViewArray((MultiLayerNetwork) model, state, false);
                            }
                        }
                    }
                    ((MultiLayerNetwork) model).setScore(score);
                } else if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                }
                // With legacy averaging across multiple devices, push the averaged model back
                // to every worker explicitly.
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            // Round complete: start counting dispatches for the next round.
            locker.set(0);
        }
    }
    // sanity checks, or the dataset may never average
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    //            throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
//        iterationsCounter.set(0);
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) DataSet(org.nd4j.linalg.dataset.api.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)

Example 3 with DataSet

Use of org.nd4j.linalg.dataset.api.DataSet in the project deeplearning4j, by deeplearning4j.

From the class TestRecordReaders, the method testClassIndexOutsideOfRangeRRMDSI_MultipleReaders.

@Test
public void testClassIndexOutsideOfRangeRRMDSI_MultipleReaders() {
    // Feature data: two sequences of two time steps, one double value per step.
    Collection<Collection<Collection<Writable>>> featureSeqs = new ArrayList<>();
    Collection<Collection<Writable>> featSeqA = new ArrayList<>();
    featSeqA.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    featSeqA.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    featureSeqs.add(featSeqA);
    Collection<Collection<Writable>> featSeqB = new ArrayList<>();
    featSeqB.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    featSeqB.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    featureSeqs.add(featSeqB);
    // Label data: class index 2 in the second sequence is out of range for 2 classes.
    Collection<Collection<Collection<Writable>>> labelSeqs = new ArrayList<>();
    Collection<Collection<Writable>> labelSeqA = new ArrayList<>();
    labelSeqA.add(Arrays.<Writable>asList(new IntWritable(0)));
    labelSeqA.add(Arrays.<Writable>asList(new IntWritable(1)));
    labelSeqs.add(labelSeqA);
    Collection<Collection<Writable>> labelSeqB = new ArrayList<>();
    labelSeqB.add(Arrays.<Writable>asList(new IntWritable(0)));
    labelSeqB.add(Arrays.<Writable>asList(new IntWritable(2)));
    labelSeqs.add(labelSeqB);
    CollectionSequenceRecordReader featureReader = new CollectionSequenceRecordReader(featureSeqs);
    CollectionSequenceRecordReader labelReader = new CollectionSequenceRecordReader(labelSeqs);
    DataSetIterator iterator = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 2, 2);
    // Iterating must fail with a DL4JException for the invalid class index;
    // any other exception type is a test failure.
    try {
        DataSet ds = iterator.next();
        fail("Expected exception");
    } catch (DL4JException e) {
        System.out.println("testClassIndexOutsideOfRangeRRMDSI_MultipleReaders(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail();
    }
}
Also used : SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) DataSet(org.nd4j.linalg.dataset.api.DataSet) ArrayList(java.util.ArrayList) CollectionSequenceRecordReader(org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader) IntWritable(org.datavec.api.writable.IntWritable) Writable(org.datavec.api.writable.Writable) DoubleWritable(org.datavec.api.writable.DoubleWritable) DoubleWritable(org.datavec.api.writable.DoubleWritable) DL4JException(org.deeplearning4j.exception.DL4JException) DL4JException(org.deeplearning4j.exception.DL4JException) Collection(java.util.Collection) IntWritable(org.datavec.api.writable.IntWritable) SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) Test(org.junit.Test)

Example 4 with DataSet

Use of org.nd4j.linalg.dataset.api.DataSet in the project deeplearning4j, by deeplearning4j.

From the class TestRecordReaders, the method testClassIndexOutsideOfRangeRRDSI.

@Test
public void testClassIndexOutsideOfRangeRRDSI() {
    // Two examples: one double feature plus an integer class index.
    // With numClasses = 2, the class index 2 in the second example is out of range.
    Collection<Collection<Writable>> c = new ArrayList<>();
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));
    CollectionRecordReader crr = new CollectionRecordReader(c);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(crr, 2, 1, 2);
    // Iterating must fail with a DL4JException; anything else fails the test.
    try {
        DataSet ds = iter.next();
        fail("Expected exception");
    } catch (DL4JException e) {
        // Fix: message previously said "testClassIndexOutsideOfRange()", which does not match
        // this test's name (copy-paste slip); now consistent with the sibling test's logging.
        System.out.println("testClassIndexOutsideOfRangeRRDSI(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail();
    }
}
Also used : DataSet(org.nd4j.linalg.dataset.api.DataSet) ArrayList(java.util.ArrayList) Collection(java.util.Collection) CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader) DoubleWritable(org.datavec.api.writable.DoubleWritable) SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) DL4JException(org.deeplearning4j.exception.DL4JException) IntWritable(org.datavec.api.writable.IntWritable) DL4JException(org.deeplearning4j.exception.DL4JException) Test(org.junit.Test)

Example 5 with DataSet

Use of org.nd4j.linalg.dataset.api.DataSet in the project deeplearning4j, by deeplearning4j.

From the class ROCTest, the method RocEvalSanityCheck.

@Test
public void RocEvalSanityCheck() {
    // Single full-batch iterator over the Iris dataset (150 examples).
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                    .layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();
    // Standardize the training data, and register the same normalizer as the iterator's
    // preprocessor so evaluation sees identically-transformed input.
    NormalizerStandardize normalizer = new NormalizerStandardize();
    DataSet dataSet = irisIter.next();
    normalizer.fit(dataSet);
    normalizer.transform(dataSet);
    irisIter.setPreProcessor(normalizer);
    int epochsRemaining = 30;
    while (epochsRemaining-- > 0) {
        network.fit(dataSet);
    }
    // Compare the built-in ROC evaluation against a manually-driven ROCMultiClass
    // fed with the network's raw output.
    ROCMultiClass viaNetwork = network.evaluateROCMultiClass(irisIter, 32);
    INDArray features = dataSet.getFeatures();
    INDArray labels = dataSet.getLabels();
    INDArray predictions = network.output(features);
    ROCMultiClass manual = new ROCMultiClass(32);
    manual.eval(labels, predictions);
    for (int clazz = 0; clazz < 3; clazz++) {
        assertEquals(manual.calculateAUC(clazz), viaNetwork.calculateAUC(clazz), 1e-6);
        double[][] curveViaNetwork = viaNetwork.getResultsAsArray(clazz);
        double[][] curveManual = manual.getResultsAsArray(clazz);
        assertArrayEquals(curveViaNetwork[0], curveManual[0], 1e-6);
        assertArrayEquals(curveViaNetwork[1], curveManual[1], 1e-6);
    }
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.api.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) NormalizerStandardize(org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)

Aggregations

DataSet (org.nd4j.linalg.dataset.api.DataSet)11 Test (org.junit.Test)9 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)9 INDArray (org.nd4j.linalg.api.ndarray.INDArray)8 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)5 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)5 ArrayList (java.util.ArrayList)4 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)4 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)4 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)4 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)4 Collection (java.util.Collection)3 DoubleWritable (org.datavec.api.writable.DoubleWritable)3 IntWritable (org.datavec.api.writable.IntWritable)3 RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator)3 SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator)3 DL4JException (org.deeplearning4j.exception.DL4JException)3 NormalizerStandardize (org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize)3 CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader)2 Writable (org.datavec.api.writable.Writable)2