Use of org.datavec.spark.transform.misc.StringToWritablesFunction in the deeplearning4j project (by deeplearning4j).
From the class TestDataVecDataSetFunctions, method testDataVecDataSetFunctionMultiLabelRegression.
/**
 * Checks the CSV-row to {@code DataSet} conversion for a multi-column (regression-style)
 * label range on Spark.
 *
 * <p>Ten CSV rows of six integers are generated (row {@code i} holds
 * {@code 10*i .. 10*i+5}), parsed with {@code StringToWritablesFunction} backed by a
 * {@code CSVRecordReader}, and converted with
 * {@code DataVecDataSetFunction(3, 5, -1, true, null, null)}. Each resulting
 * {@code DataSet} must expose the first three values of its row as features and the last
 * three as labels, and every one of the ten rows must be represented in the output.
 *
 * <p>NOTE(review): the constructor arguments presumably mean "label columns 3..5,
 * regression mode, no pre-processor / converter" - confirm against the
 * {@code DataVecDataSetFunction} documentation.
 *
 * @throws Exception declared by the original signature; not thrown explicitly here
 */
@Test
public void testDataVecDataSetFunctionMultiLabelRegression() throws Exception {
    final int numExamples = 10;
    final int numColumns = 6;
    final int numFeatures = 3;

    JavaSparkContext sc = getContext();

    // Row i is "10*i,10*i+1,...,10*i+5", so the first value identifies the row.
    List<String> csvLines = new ArrayList<>();
    for (int row = 0; row < numExamples; row++) {
        StringBuilder line = new StringBuilder();
        for (int col = 0; col < numColumns; col++) {
            if (col > 0) {
                line.append(",");
            }
            line.append(10 * row + col);
        }
        csvLines.add(line.toString());
    }

    List<DataSet> collected = sc.parallelize(csvLines)
            .map(new StringToWritablesFunction(new CSVRecordReader()))
            .map(new DataVecDataSetFunction(3, 5, -1, true, null, null))
            .collect();
    assertEquals(numExamples, collected.size());

    boolean[] visited = new boolean[numExamples];
    for (DataSet example : collected) {
        INDArray features = example.getFeatureMatrix();
        INDArray labels = example.getLabels();
        assertEquals(numFeatures, features.length());
        assertEquals(numFeatures, labels.length());

        // Recover the originating row from the first feature value; the test does not
        // rely on the order in which the collected examples arrive.
        int rowIdx = ((int) features.getDouble(0)) / 10;
        visited[rowIdx] = true;

        int base = 10 * rowIdx;
        for (int k = 0; k < numFeatures; k++) {
            assertEquals(base + k, (int) features.getDouble(k));
            assertEquals(base + k + numFeatures, (int) labels.getDouble(k));
        }
    }

    int visitedCount = 0;
    for (boolean flag : visited) {
        if (flag) {
            visitedCount++;
        }
    }
    assertEquals(numExamples, visitedCount);
}
Aggregations