Search in sources:

Example 1 with SequenceRecordReaderFunction

Use of org.datavec.spark.functions.SequenceRecordReaderFunction in the deeplearning4j project (by deeplearning4j).

From the class TestDataVecDataSetFunctions, method testDataVecSequenceDataSetFunction:

@Test
public void testDataVecSequenceDataSetFunction() throws Exception {
    JavaSparkContext sc = getContext();
    //Test Spark record reader functionality vs. local
    File f = new File("src/test/resources/csvsequence/csvsequence_0.txt");
    String path = f.getPath();
    //Strip the 17-character file name ("csvsequence_0.txt") to get the parent folder
    String folder = path.substring(0, path.length() - 17);
    path = folder + "*";
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    //3 CSV sequences
    assertEquals(3, origData.count());
    //Reader config: skip 1 header line, comma-delimited
    SequenceRecordReader seqRR = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReaderFunction rrf = new SequenceRecordReaderFunction(seqRR);
    //(path, stream) pairs -> sequences of Writables -> DataSets
    JavaRDD<List<List<Writable>>> rdd = origData.map(rrf);
    JavaRDD<DataSet> data = rdd.map(new DataVecSequenceDataSetFunction(2, -1, true, null, null));
    List<DataSet> collected = data.collect();
    //Load normally (i.e., not via Spark), and check that we get the same results (order notwithstanding)
    InputSplit is = new FileSplit(new File(folder), new String[] { "txt" }, true);
    SequenceRecordReader seqRR2 = new CSVSequenceRecordReader(1, ",");
    seqRR2.initialize(is);
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(seqRR2, 1, -1, 2, true);
    List<DataSet> listLocal = new ArrayList<>(3);
    while (iter.hasNext()) {
        listLocal.add(iter.next());
    }
    //Compare:
    assertEquals(3, collected.size());
    assertEquals(3, listLocal.size());
    //Check that results are the same (order notwithstanding)
    boolean[] found = new boolean[3];
    for (int i = 0; i < 3; i++) {
        int foundIndex = -1;
        DataSet ds = collected.get(i);
        for (int j = 0; j < 3; j++) {
            if (ds.equals(listLocal.get(j))) {
                if (foundIndex != -1)
                    //Already matched this Spark value -> it equals two or more local values (shouldn't happen)
                    fail();
                foundIndex = j;
                if (found[foundIndex])
                    //Another Spark value already matched this local value -> suggests duplicates in the Spark list
                    fail();
                //Mark this local value as matched
                found[foundIndex] = true;
            }
        }
    }
    int count = 0;
    for (boolean b : found) {
        if (b)
            count++;
    }
    //Expect exactly 3 pairwise matches between the Spark and local versions (all 3 local values matched)
    assertEquals(3, count);
}
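
The Spark half of this pipeline can be distilled into a small standalone helper. This is a minimal sketch rather than part of the test: the class name, the buildSequenceDataSets method, the directory parameter, and the import path assumed for DataVecSequenceDataSetFunction (org.deeplearning4j.spark.datavec) are illustrative, while the reader configuration and the constructor arguments simply mirror the test above.

import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.input.PortableDataStream;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.SequenceRecordReaderFunction;
import org.deeplearning4j.spark.datavec.DataVecSequenceDataSetFunction;
import org.nd4j.linalg.dataset.DataSet;

public class SequencePipelineSketch {

    //Hypothetical helper: builds a JavaRDD<DataSet> from a directory of CSV sequence files
    public static JavaRDD<DataSet> buildSequenceDataSets(JavaSparkContext sc, String sequenceDir) {
        //Read each file as a (path, binary stream) pair
        JavaPairRDD<String, PortableDataStream> files = sc.binaryFiles(sequenceDir + "/*");
        //Parse each stream into a sequence (a list of time steps, each a list of Writables),
        //skipping 1 header line and splitting on commas, as in the test
        SequenceRecordReader seqRR = new CSVSequenceRecordReader(1, ",");
        JavaRDD<List<List<Writable>>> sequences = files.map(new SequenceRecordReaderFunction(seqRR));
        //Convert each sequence to a DataSet; the arguments (2, -1, true, null, null) mirror the test above
        return sequences.map(new DataVecSequenceDataSetFunction(2, -1, true, null, null));
    }
}

On the local (non-Spark) side, the test reads the same files with SequenceRecordReaderDataSetIterator and then compares the two collections of DataSets pairwise, ignoring ordering.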
Also used:
CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader)
SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader)
SequenceRecordReaderFunction (org.datavec.spark.functions.SequenceRecordReaderFunction)
SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator)
DataSet (org.nd4j.linalg.dataset.DataSet)
DataVecSequenceDataSetFunction usage as shown above
PortableDataStream (org.apache.spark.input.PortableDataStream)
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)
Writable (org.datavec.api.writable.Writable)
FileSplit (org.datavec.api.split.FileSplit)
InputSplit (org.datavec.api.split.InputSplit)
NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit)
File (java.io.File)
ArrayList (java.util.ArrayList)
List (java.util.List)
BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)
Test (org.junit.Test)

Aggregations

File (java.io.File) 1
ArrayList (java.util.ArrayList) 1
List (java.util.List) 1
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext) 1
PortableDataStream (org.apache.spark.input.PortableDataStream) 1
SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader) 1
CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) 1
FileSplit (org.datavec.api.split.FileSplit) 1
InputSplit (org.datavec.api.split.InputSplit) 1
NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit) 1
Writable (org.datavec.api.writable.Writable) 1
SequenceRecordReaderFunction (org.datavec.spark.functions.SequenceRecordReaderFunction) 1
SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) 1
BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest) 1
Test (org.junit.Test) 1
DataSet (org.nd4j.linalg.dataset.DataSet) 1