
Example 6 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in the deeplearning4j project.

From the class RecordReaderMultiDataSetIteratorTest, method testSplittingCSV.

@Test
public void testSplittingCSV() throws Exception {
    //Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
    //Inputs: columns 0 and 1-2
    //Outputs: columns 3, and 4->OneHot
    //Need to manually extract the expected arrays from a standard DataSet for comparison
    RecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);
    RecordReader rr2 = new CSVRecordReader(0, ",");
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("reader", rr2)
            .addInput("reader", 0, 0)
            .addInput("reader", 1, 2)
            .addOutput("reader", 3, 3)
            .addOutputOneHot("reader", 4, 3)
            .build();
    while (rrdsi.hasNext()) {
        DataSet ds = rrdsi.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = rrmdsi.next();
        assertEquals(2, mds.getFeatures().length);
        assertEquals(2, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();
        assertNotNull(fmds);
        assertNotNull(lmds);
        for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]);
        for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]);
        //Get the subsets of the original iris data
        INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.point(0));
        INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(1, 2, true));
        INDArray expOut1 = fds.get(NDArrayIndex.all(), NDArrayIndex.point(3));
        INDArray expOut2 = lds;
        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(expOut1, lmds[0]);
        assertEquals(expOut2, lmds[1]);
    }
    assertFalse(rrmdsi.hasNext());
}
Also used: MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), INDArray (org.nd4j.linalg.api.ndarray.INDArray), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), DataSet (org.nd4j.linalg.dataset.DataSet), RecordReader (org.datavec.api.records.reader.RecordReader), ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)
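
A hedged standalone sketch (not from the project's tests) of the same column-splitting setup, shown with the resulting per-minibatch array shapes. It assumes the same imports and the iris.txt classpath resource as the test above; names such as irisSplitSketch are illustrative only.

public static void irisSplitSketch() throws Exception {
    RecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("reader", rr)
            .addInput("reader", 0, 0)        //input 0: column 0
            .addInput("reader", 1, 2)        //input 1: columns 1 to 2 (inclusive)
            .addOutput("reader", 3, 3)       //output 0: column 3 as-is
            .addOutputOneHot("reader", 4, 3) //output 1: column 4, one-hot over 3 classes
            .build();
    while (iter.hasNext()) {
        MultiDataSet mds = iter.next();
        //Each minibatch of 10 Iris rows yields 2 feature arrays and 2 label arrays:
        System.out.println(java.util.Arrays.toString(mds.getFeatures(0).shape())); //[10, 1]
        System.out.println(java.util.Arrays.toString(mds.getFeatures(1).shape())); //[10, 2]
        System.out.println(java.util.Arrays.toString(mds.getLabels(0).shape()));   //[10, 1]
        System.out.println(java.util.Arrays.toString(mds.getLabels(1).shape()));   //[10, 3]
    }
}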

Example 7 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in the deeplearning4j project.

From the class RecordReaderMultiDataSetIteratorTest, method testImagesRRDMSI_Batched.

@Test
public void testImagesRRDMSI_Batched() throws Exception {
    File parentDir = Files.createTempDir();
    parentDir.deleteOnExit();
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
    File f1 = new File(str1);
    File f2 = new File(str2);
    f1.mkdirs();
    f2.mkdirs();
    writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
    int outputNum = 2;
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(2)
            .addReader("rr1", rr1)
            .addReader("rr1s", rr1s)
            .addInput("rr1", 0, 0)
            .addInput("rr1s", 0, 0)
            .addOutputOneHot("rr1s", 1, outputNum)
            .build();
    //Now, do the same thing with two standard RecordReaderDataSetIterators, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));
    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 2, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 2, 1, 2);
    MultiDataSet mds = trainDataIterator.next();
    DataSet d1 = dsi1.next();
    DataSet d2 = dsi2.next();
    assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
    assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
    assertEquals(d1.getLabels(), mds.getLabels(0));
}
Also used: DataSet (org.nd4j.linalg.dataset.DataSet), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), ParentPathLabelGenerator (org.datavec.api.io.labels.ParentPathLabelGenerator), ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader), Test (org.junit.Test)
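
A hedged continuation (not part of the test above) that reuses trainDataIterator to show what each batched MultiDataSet contains; shapes assume the batch size of 2 and the two images written by the test. A two-input ComputationGraph could consume this data directly via fit(MultiDataSet) or fit(MultiDataSetIterator); the graph configuration itself is omitted here.

trainDataIterator.reset(); //the test already consumed the single available batch
MultiDataSet batch = trainDataIterator.next();
INDArray large = batch.getFeatures(0); //from "rr1":  shape [2, 1, 10, 10]
INDArray small = batch.getFeatures(1); //from "rr1s": shape [2, 1, 5, 5]
INDArray labels = batch.getLabels(0);  //one-hot over 2 classes: shape [2, 2]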

Example 8 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in the deeplearning4j project.

From the class ParameterAveragingTrainingMaster, method executeTraining.

@Override
public void executeTraining(SparkComputationGraph graph, JavaRDD<DataSet> trainingData) {
    if (numWorkers == null)
        numWorkers = graph.getSparkContext().defaultParallelism();
    JavaRDD<MultiDataSet> mdsTrainingData = trainingData.map(new DataSetToMultiDataSetFn());
    executeTrainingMDS(graph, mdsTrainingData);
}
Also used: MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), DataSetToMultiDataSetFn (org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn)
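
Conceptually, DataSetToMultiDataSetFn wraps the single feature and label arrays of each DataSet into a MultiDataSet with one entry per slot, so the ComputationGraph training path can be reused. A minimal sketch of that conversion is shown below; it is not the actual DataSetToMultiDataSetFn source and omits the feature/label mask arrays that the real conversion also carries across.

public static MultiDataSet toMultiDataSet(DataSet ds) {
    //One feature array and one label array become the single-entry arrays of a MultiDataSet
    return new org.nd4j.linalg.dataset.MultiDataSet(ds.getFeatureMatrix(), ds.getLabels());
}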

Example 9 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in the deeplearning4j project.

From the class TestExport, method testBatchAndExportMultiDataSetsFunction.

@Test
public void testBatchAndExportMultiDataSetsFunction() throws Exception {
    String baseDir = System.getProperty("java.io.tmpdir");
    baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExportMDS/");
    baseDir = baseDir.replaceAll("\\\\", "/");
    File f = new File(baseDir);
    if (f.exists())
        FileUtils.deleteDirectory(f);
    f.mkdir();
    f.deleteOnExit();
    int minibatchSize = 5;
    int nIn = 4;
    int nOut = 3;
    List<MultiDataSet> dataSets = new ArrayList<>();
    //Larger than minibatch size -> tests splitting
    dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(10, nIn), Nd4j.create(10, nOut)));
    for (int i = 0; i < 98; i++) {
        if (i % 2 == 0) {
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(5, nIn), Nd4j.create(5, nOut)));
        } else {
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(3, nIn), Nd4j.create(3, nOut)));
        }
    }
    Collections.shuffle(dataSets, new Random(12345));
    JavaRDD<MultiDataSet> rdd = sc.parallelize(dataSets);
    //For testing purposes (should get exactly 100 out, but maybe more with more partitions)
    rdd = rdd.repartition(1);
    JavaRDD<String> pathsRdd = rdd.mapPartitionsWithIndex(new BatchAndExportMultiDataSetsFunction(minibatchSize, "file:///" + baseDir), true);
    List<String> paths = pathsRdd.collect();
    assertEquals(100, paths.size());
    File[] files = f.listFiles();
    assertNotNull(files);
    int count = 0;
    for (File file : files) {
        if (!file.getPath().endsWith(".bin"))
            continue;
        System.out.println(file);
        MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
        ds.load(file);
        assertEquals(minibatchSize, ds.getFeatures(0).size(0));
        assertEquals(minibatchSize, ds.getLabels(0).size(0));
        count++;
    }
    assertEquals(100, count);
    FileUtils.deleteDirectory(f);
}
Also used: BatchAndExportMultiDataSetsFunction (org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction), ArrayList (java.util.ArrayList), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), Random (java.util.Random), File (java.io.File), Test (org.junit.Test), BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)
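
For reference, a small roundtrip sketch of the serialization this test relies on: it assumes MultiDataSet.save(File) as the counterpart of the load(File) call used above, with illustrative sizes matching nIn = 4 and nOut = 3.

MultiDataSet saved = new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(5, 4), Nd4j.create(5, 3));
File tmp = File.createTempFile("mds_sketch", ".bin");
tmp.deleteOnExit();
saved.save(tmp);
MultiDataSet loaded = new org.nd4j.linalg.dataset.MultiDataSet();
loaded.load(tmp);
assertEquals(saved.getFeatures(0), loaded.getFeatures(0)); //same feature array after the roundtrip
assertEquals(saved.getLabels(0), loaded.getLabels(0));     //same label array after the roundtrip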

Example 10 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in the deeplearning4j project.

From the class ScoreExamplesFunctionAdapter, method call.

@Override
public Iterable<Double> call(Iterator<MultiDataSet> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    network.setParams(val);
    List<Double> ret = new ArrayList<>();
    List<MultiDataSet> collect = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            MultiDataSet ds = iterator.next();
            int n = ds.getFeatures(0).size(0);
            collect.add(ds);
            nExamples += n;
        }
        totalCount += nExamples;
        MultiDataSet data = org.nd4j.linalg.dataset.MultiDataSet.merge(collect);
        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();
        for (double doubleScore : doubleScores) {
            ret.add(doubleScore);
        }
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret;
}
Also used: GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner), INDArray (org.nd4j.linalg.api.ndarray.INDArray), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), ArrayList (java.util.ArrayList), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)
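
The key step above is org.nd4j.linalg.dataset.MultiDataSet.merge, which stacks the collected examples into a single minibatch before scoring. A minimal hedged sketch of that behavior, standalone and with illustrative sizes:

List<MultiDataSet> parts = new ArrayList<>();
parts.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(2, 4), Nd4j.create(2, 3)));
parts.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(3, 4), Nd4j.create(3, 3)));
MultiDataSet merged = org.nd4j.linalg.dataset.MultiDataSet.merge(parts);
assertEquals(5, merged.getFeatures(0).size(0)); //2 + 3 examples stacked along dimension 0
assertEquals(5, merged.getLabels(0).size(0));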

Aggregations

MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 28
Test (org.junit.Test): 12
ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 11
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 10
MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 10
SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 8
CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 8
DataSet (org.nd4j.linalg.dataset.DataSet): 8
FileSplit (org.datavec.api.split.FileSplit): 7
ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader): 6
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 6
RecordReader (org.datavec.api.records.reader.RecordReader): 5
CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 5
NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit): 5
ArrayList (java.util.ArrayList): 4
RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 4
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 4
AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator): 3
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 3
Random (java.util.Random): 2