Use of org.datavec.spark.functions.FlatMapFunctionAdapter in project deeplearning4j by deeplearning4j.
The class MLLibUtil, method fromLabeledPoint.
/**
 * Converts an RDD of labeled points into an RDD of data sets, using the
 * batch size to decide how many partitions the result is spread across.
 *
 * <p>Each labeled point is paired with its index in {@code data}, converted
 * with {@code MLLibUtil.fromLabeledPoint(point, numPossibleLabels)}, passed
 * through {@code reduceByKey} with {@code count / batchSize} partitions
 * (never fewer than one), and finally flattened back into a plain RDD.
 *
 * @param data the data to convert
 * @param numPossibleLabels the number of possible labels
 * @param batchSize the batch size; must be greater than zero
 * @return the new rdd
 * @throws IllegalArgumentException if {@code batchSize} is not positive
 */
public static JavaRDD<DataSet> fromLabeledPoint(JavaRDD<LabeledPoint> data, final int numPossibleLabels, int batchSize) {
    // A zero batch size would divide by zero below and a negative one would
    // produce a negative partition count, so reject both up front.
    if (batchSize <= 0) {
        throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
    }

    // Key every example by its index and convert it to a DataSet in the same pass.
    JavaPairRDD<Long, DataSet> mappedData = data.zipWithIndex().mapToPair(new PairFunction<Tuple2<LabeledPoint, Long>, Long, DataSet>() {
        @Override
        public Tuple2<Long, DataSet> call(Tuple2<LabeledPoint, Long> indexedPoint) throws Exception {
            return new Tuple2<>(indexedPoint._2(), MLLibUtil.fromLabeledPoint(indexedPoint._1(), numPossibleLabels));
        }
    });

    // The conversion is one-to-one, so counting the input gives the same number
    // as counting the mapped RDD without running the conversion an extra time.
    long numExamples = data.count();

    // Integer division yields 0 when there are fewer examples than batchSize
    // (including an empty RDD); a shuffle needs at least one partition.
    int numPartitions = (int) Math.max(1L, Math.min((long) Integer.MAX_VALUE, numExamples / batchSize));

    // NOTE(review): zipWithIndex gives every element a distinct index, so each key
    // is unique and this merge function is not expected to be invoked; the
    // reduceByKey step effectively only redistributes the data over numPartitions.
    // Confirm whether keying by (index / batchSize) was the original intent.
    JavaPairRDD<Long, DataSet> aggregated = mappedData.reduceByKey(new Function2<DataSet, DataSet, DataSet>() {
        @Override
        public DataSet call(DataSet v1, DataSet v2) throws Exception {
            return new DataSet(Nd4j.vstack(v1.getFeatureMatrix(), v2.getFeatureMatrix()), Nd4j.vstack(v1.getLabels(), v2.getLabels()));
        }
    }, numPartitions);

    // Drop the index keys. The DataSet itself is returned as the Iterable, so the
    // output contains whatever elements iterating that DataSet yields.
    return aggregated.flatMap(new BaseFlatMapFunctionAdaptee<Tuple2<Long, DataSet>, DataSet>(new FlatMapFunctionAdapter<Tuple2<Long, DataSet>, DataSet>() {
        @Override
        public Iterable<DataSet> call(Tuple2<Long, DataSet> indexedDataSet) throws Exception {
            return indexedDataSet._2();
        }
    }));
}
Aggregations