Search in sources:

Example 1 with StringToWritablesFunction

Use of org.datavec.spark.transform.misc.StringToWritablesFunction in project deeplearning4j by deeplearning4j.

From class TestDataVecDataSetFunctions, method testDataVecDataSetFunctionMultiLabelRegression.

/**
 * Verifies that {@code DataVecDataSetFunction} can split each CSV record into
 * multiple regression label columns.
 *
 * <p>Ten rows are generated. Row {@code r} holds the six integers
 * {@code 10*r .. 10*r+5}. Each resulting DataSet is checked for the following:
 * <ul>
 *   <li>the features are the first three values of its row;</li>
 *   <li>the labels are the last three values of its row;</li>
 *   <li>every source row appears in the collected output.</li>
 * </ul>
 */
@Test
public void testDataVecDataSetFunctionMultiLabelRegression() throws Exception {
    JavaSparkContext sc = getContext();

    final int numRows = 10;
    final int numColumns = 6;
    List<String> csvLines = new ArrayList<>();
    for (int row = 0; row < numRows; row++) {
        StringBuilder line = new StringBuilder();
        for (int col = 0; col < numColumns; col++) {
            // Separator goes before every value except the first one.
            if (col > 0) {
                line.append(",");
            }
            line.append(10 * row + col);
        }
        csvLines.add(line.toString());
    }

    // NOTE(review): the arguments (3, 5, -1, true, null, null) presumably mean
    // "label columns 3..5, regression mode". The assertions below are consistent
    // with that reading; confirm against DataVecDataSetFunction's constructor.
    List<DataSet> collected = sc.parallelize(csvLines)
            .map(new StringToWritablesFunction(new CSVRecordReader()))
            .map(new DataVecDataSetFunction(3, 5, -1, true, null, null))
            .collect();
    assertEquals(10, collected.size());

    boolean[] rowSeen = new boolean[numRows];
    for (DataSet dataSet : collected) {
        INDArray features = dataSet.getFeatureMatrix();
        INDArray labels = dataSet.getLabels();
        assertEquals(3, features.length());
        assertEquals(3, labels.length());

        // The first feature is 10*row, so integer division recovers the source row.
        int sourceRow = ((int) features.getDouble(0)) / 10;
        rowSeen[sourceRow] = true;
        for (int k = 0; k < 3; k++) {
            int expectedFeature = 10 * sourceRow + k;
            assertEquals(expectedFeature, (int) features.getDouble(k));
            assertEquals(expectedFeature + 3, (int) labels.getDouble(k));
        }
    }

    int distinctRows = 0;
    for (boolean wasSeen : rowSeen) {
        if (wasSeen) {
            distinctRows++;
        }
    }
    assertEquals(10, distinctRows);
}
Also used: StringToWritablesFunction (org.datavec.spark.transform.misc.StringToWritablesFunction), DataSet (org.nd4j.linalg.dataset.DataSet), ArrayList (java.util.ArrayList), INDArray (org.nd4j.linalg.api.ndarray.INDArray), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), List (java.util.List), JavaSparkContext (org.apache.spark.api.java.JavaSparkContext), BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest), Test (org.junit.Test)

Aggregations

ArrayList (java.util.ArrayList): 1; List (java.util.List): 1; JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 1; CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 1; StringToWritablesFunction (org.datavec.spark.transform.misc.StringToWritablesFunction): 1; BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 1; Test (org.junit.Test): 1; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1; DataSet (org.nd4j.linalg.dataset.DataSet): 1