
Example 36 with DataSet

Use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

From class TestDataVecDataSetFunctions, method testDataVecDataSetFunction:

@Test
public void testDataVecDataSetFunction() throws Exception {
    JavaSparkContext sc = getContext();
    //Test Spark record reader functionality vs. local
    File f = new File("src/test/resources/imagetest/0/a.bmp");
    //Need this for Spark: can't infer without init call
    List<String> labelsList = Arrays.asList("0", "1");
    String path = f.getPath();
    String folder = path.substring(0, path.length() - 7);
    path = folder + "*";
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    //4 images
    assertEquals(4, origData.count());
    ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
    rr.setLabels(labelsList);
    org.datavec.spark.functions.RecordReaderFunction rrf = new org.datavec.spark.functions.RecordReaderFunction(rr);
    JavaRDD<List<Writable>> rdd = origData.map(rrf);
    JavaRDD<DataSet> data = rdd.map(new DataVecDataSetFunction(1, 2, false));
    List<DataSet> collected = data.collect();
    //Load normally (i.e., not via Spark) and check that we get the same results (order notwithstanding)
    InputSplit is = new FileSplit(new File(folder), new String[] { "bmp" }, true);
    ImageRecordReader irr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
    irr.initialize(is);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(irr, 1, 1, 2);
    List<DataSet> listLocal = new ArrayList<>(4);
    while (iter.hasNext()) {
        listLocal.add(iter.next());
    }
    //Compare:
    assertEquals(4, collected.size());
    assertEquals(4, listLocal.size());
    //Check that results are the same (order notwithstanding)
    boolean[] found = new boolean[4];
    for (int i = 0; i < 4; i++) {
        int foundIndex = -1;
        DataSet ds = collected.get(i);
        for (int j = 0; j < 4; j++) {
            if (ds.equals(listLocal.get(j))) {
                if (foundIndex != -1)
                    //Already found a match for this value -> suggests this Spark value equals two or more local values (shouldn't happen)
                    fail();
                foundIndex = j;
                if (found[foundIndex])
                    //One of the other Spark values already matched this local one -> suggests duplicates in the Spark list
                    fail();
                //Mark this local DataSet as matched
                found[foundIndex] = true;
            }
        }
    }
    int count = 0;
    for (boolean b : found) {
        if (b)
            count++;
    }
    //Expect all 4 (and exactly 4) pairwise matches between the Spark and local versions
    assertEquals(4, count);
}
Also used : SequenceRecordReaderFunction(org.datavec.spark.functions.SequenceRecordReaderFunction) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) PortableDataStream(org.apache.spark.input.PortableDataStream) SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) FileSplit(org.datavec.api.split.FileSplit) ArrayList(java.util.ArrayList) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) NumberedFileInputSplit(org.datavec.api.split.NumberedFileInputSplit) InputSplit(org.datavec.api.split.InputSplit) ParentPathLabelGenerator(org.datavec.api.io.labels.ParentPathLabelGenerator) File(java.io.File) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)
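
For reference, the sketch below (not from the project; all values are made up) shows roughly the conversion that DataVecDataSetFunction(labelIndex, numPossibleLabels, false) performs on a single numeric record: every column except the label becomes a feature value, and the label column becomes a one-hot vector. The real class also handles NDArrayWritable features (as produced by ImageRecordReader) and regression targets.

import java.util.Arrays;
import java.util.List;

import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

public class RecordToDataSetSketch {
    public static void main(String[] args) {
        // A toy numeric record: three feature columns plus a class label at index 3
        List<Writable> record = Arrays.<Writable>asList(
                new DoubleWritable(0.1), new DoubleWritable(0.2),
                new DoubleWritable(0.3), new IntWritable(1));
        int labelIndex = 3;
        int numClasses = 2;

        // Features: every column before the label, as a row vector
        double[] featureVals = new double[labelIndex];
        for (int i = 0; i < labelIndex; i++) {
            featureVals[i] = record.get(i).toDouble();
        }
        INDArray features = Nd4j.create(featureVals);

        // Label: one-hot vector for the class index
        INDArray labels = FeatureUtil.toOutcomeVector(record.get(labelIndex).toInt(), numClasses);

        DataSet ds = new DataSet(features, labels);
        System.out.println(ds);
    }
}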

Example 37 with DataSet

Use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

From class PathSparkDataSetIterator, method load:

protected synchronized DataSet load(String path) {
    // Lazily create the Hadoop FileSystem the first time a path is loaded
    if (fileSystem == null) {
        try {
            fileSystem = FileSystem.get(new URI(path), new Configuration());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    // Deserialize the DataSet from the file at the given path
    DataSet ds = new DataSet();
    try (FSDataInputStream inputStream = fileSystem.open(new Path(path), BUFFER_SIZE)) {
        ds.load(inputStream);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    cursor++;
    return ds;
}
Also used : Path(org.apache.hadoop.fs.Path) Configuration(org.apache.hadoop.conf.Configuration) DataSet(org.nd4j.linalg.dataset.DataSet) FSDataInputStream(org.apache.hadoop.fs.FSDataInputStream) IOException(java.io.IOException) URI(java.net.URI) IOException(java.io.IOException)
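
The write side is not shown in the project snippet above; the following is a minimal sketch (the HDFS path and toy arrays are placeholders) of how a DataSet could be serialized to a Hadoop FileSystem in the binary format that DataSet.load(InputStream) reads back.

import java.net.URI;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class SaveDataSetToHdfs {
    public static void main(String[] args) throws Exception {
        // Placeholder path and toy data; in practice these come from your pipeline
        String path = "hdfs:///tmp/datasets/dataset_0.bin";
        DataSet ds = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));

        FileSystem fs = FileSystem.get(new URI(path), new Configuration());
        try (FSDataOutputStream out = fs.create(new Path(path))) {
            // Binary format that PathSparkDataSetIterator.load() reads back via DataSet.load()
            ds.save(out);
        }
    }
}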

Example 38 with DataSet

Use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

From class ScoreFlatMapFunctionCGDataSetAdapter, method call:

@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<DataSet> dataSetIterator) throws Exception {
    if (!dataSetIterator.hasNext()) {
        return Collections.singletonList(new Tuple2<>(0, 0.0));
    }
    //Does batching where appropriate
    DataSetIterator iter = new IteratorDataSetIterator(dataSetIterator, minibatchSize);
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
    network.init();
    //.value() is shared by all executors on single machine -> OK, as params are not changed in score function
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    network.setParams(val);
    List<Tuple2<Integer, Double>> out = new ArrayList<>();
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        double score = network.score(ds, false);
        int numExamples = ds.getFeatureMatrix().size(0);
        out.add(new Tuple2<>(numExamples, score * numExamples));
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    return out;
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) IteratorDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorDataSetIterator) IteratorDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
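
Each tuple emitted above pairs an example count with score * count, which lets the caller compute a size-weighted average afterwards. The snippet below is only a plain-Java illustration of that arithmetic with made-up numbers; the actual driver-side aggregation in deeplearning4j may be organized differently.

import java.util.Arrays;
import java.util.List;

import scala.Tuple2;

public class AverageScoreSketch {
    public static void main(String[] args) {
        // Example per-minibatch outputs: (numExamples, score * numExamples)
        List<Tuple2<Integer, Double>> partitionScores = Arrays.asList(
                new Tuple2<Integer, Double>(32, 32 * 0.91),
                new Tuple2<Integer, Double>(32, 32 * 0.87),
                new Tuple2<Integer, Double>(16, 16 * 1.02));

        long totalExamples = 0;
        double totalScore = 0.0;
        for (Tuple2<Integer, Double> t : partitionScores) {
            totalExamples += t._1();
            totalScore += t._2();
        }
        // Weighted average: each minibatch contributes in proportion to its size
        System.out.println("Average score per example: " + totalScore / totalExamples);
    }
}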

Example 39 with DataSet

Use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

From class SparkDl4jLayer, method fit:

/**
 * Fit the layer from a text file loaded via the Spark context.
 * @param path the path to the text file
 * @param labelIndex the index of the label column
 * @param recordReader the record reader used to parse each line
 * @return the fitted layer
 */
public Layer fit(String path, int labelIndex, RecordReader recordReader) {
    FeedForwardLayer ffLayer = (FeedForwardLayer) conf.getLayer();
    JavaRDD<String> lines = sc.textFile(path);
    // Map each line to a DataSet (features plus one-hot label) via the record reader
    JavaRDD<DataSet> points = lines.map(new RecordReaderFunction(recordReader, labelIndex, ffLayer.getNOut()));
    return fitDataSet(points);
}
Also used : RecordReaderFunction(org.deeplearning4j.spark.datavec.RecordReaderFunction) DataSet(org.nd4j.linalg.dataset.DataSet) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer)
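
A usage sketch (not from the project) of the method above: configuring a single output layer and fitting it from a CSV file whose last column holds the class label. The builder calls and the SparkDl4jLayer constructor are as I understand them for this era of deeplearning4j, but the exact signatures and the file path are assumptions, so treat this as an outline rather than a verified example.

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.impl.layer.SparkDl4jLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class SparkDl4jLayerFitSketch {
    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext(
                new SparkConf().setMaster("local[*]").setAppName("layer-fit-sketch"));

        // A single feed-forward output layer: 4 numeric features -> 3 classes
        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(4).nOut(3)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.SOFTMAX)
                        .build())
                .build();

        SparkDl4jLayer sparkLayer = new SparkDl4jLayer(sc, conf);
        // Label is the last column (index 4) of each CSV line; the path is a placeholder
        Layer fitted = sparkLayer.fit("hdfs:///data/iris.csv", 4, new CSVRecordReader());
        sc.stop();
    }
}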

Example 40 with DataSet

Use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

From class PathSparkDataSetIterator, method next:

@Override
public DataSet next() {
    DataSet ds;
    // Use a previously preloaded DataSet if one is available; otherwise load the next path
    if (preloadedDataSet != null) {
        ds = preloadedDataSet;
        preloadedDataSet = null;
    } else {
        ds = load(iter.next());
    }
    //May be null for layerwise pretraining
    totalOutcomes = ds.getLabels() == null ? 0 : ds.getLabels().size(1);
    inputColumns = ds.getFeatureMatrix().size(1);
    batch = ds.numExamples();
    if (preprocessor != null)
        preprocessor.preProcess(ds);
    return ds;
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet)
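
As a usage note, this is consumed like any other DataSetIterator, with each next() call loading and preprocessing one serialized DataSet on demand. A minimal fragment, assuming 'iterator' is such a DataSetIterator and 'network' is an already-initialized MultiLayerNetwork (neither is defined here):

while (iterator.hasNext()) {
    DataSet ds = iterator.next();   // loads the next path via load(), applies the preprocessor
    network.fit(ds);                // or network.score(ds, false) when only evaluating
}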

Aggregations

DataSet (org.nd4j.linalg.dataset.DataSet): 334
Test (org.junit.Test): 226
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 194
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 93
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 82
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 79
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 73
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 62
ArrayList (java.util.ArrayList): 50
MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator): 41
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 38
BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 34
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 32
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 31
MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet): 31
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 25
SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 24
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 24
CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 23
ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 23