Search in sources :

Example 1 with FlatMapFunctionAdapter

Use of org.datavec.spark.functions.FlatMapFunctionAdapter in project deeplearning4j by deeplearning4j.

The class MLLibUtil, method fromLabeledPoint.

/**
 * Converts an RDD of labeled points into an RDD of data sets.
 *
 * <p>Each labeled point is paired with its {@code zipWithIndex} position, converted
 * with {@code MLLibUtil.fromLabeledPoint(LabeledPoint, int)}, passed through a
 * {@code reduceByKey} that vertically stacks the features and labels of data sets
 * sharing a key, and finally flattened through the {@code Iterable<DataSet>} view
 * of each resulting data set.
 *
 * <p>The batch size is used only to choose the number of shuffle partitions
 * ({@code count / batchSize}, with a minimum of one partition).
 *
 * <p>NOTE(review): the keys produced by {@code zipWithIndex} are unique per record,
 * so the reduce step presumably never merges two data sets and no batches of
 * {@code batchSize} examples are actually formed. Confirm whether keying by
 * {@code index / batchSize} was intended before relying on batched output.
 *
 * @param data the data to convert
 * @param numPossibleLabels the number of possible labels
 * @param batchSize the batch size; must be positive
 * @return the new rdd
 * @throws IllegalArgumentException if {@code batchSize} is not positive
 */
public static JavaRDD<DataSet> fromLabeledPoint(JavaRDD<LabeledPoint> data, final int numPossibleLabels, int batchSize) {
    // Fail fast with a descriptive message instead of a bare division-by-zero
    // (batchSize == 0) or a non-positive partition count (batchSize < 0).
    if (batchSize <= 0) {
        throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
    }

    // Key every labeled point by its position in the RDD: (point, index) -> (index, point).
    JavaPairRDD<Long, LabeledPoint> dataWithIndex = data.zipWithIndex().mapToPair(new PairFunction<Tuple2<LabeledPoint, Long>, Long, LabeledPoint>() {

        @Override
        public Tuple2<Long, LabeledPoint> call(Tuple2<LabeledPoint, Long> labeledPointLongTuple2) throws Exception {
            return new Tuple2<>(labeledPointLongTuple2._2(), labeledPointLongTuple2._1());
        }
    });

    // Convert each labeled point to a DataSet, keeping its index as the key.
    JavaPairRDD<Long, DataSet> mappedData = dataWithIndex.mapToPair(new PairFunction<Tuple2<Long, LabeledPoint>, Long, DataSet>() {

        @Override
        public Tuple2<Long, DataSet> call(Tuple2<Long, LabeledPoint> longLabeledPointTuple2) throws Exception {
            return new Tuple2<>(longLabeledPointTuple2._1(), MLLibUtil.fromLabeledPoint(longLabeledPointTuple2._2(), numPossibleLabels));
        }
    });

    // Integer division truncates to 0 when there are fewer records than batchSize;
    // a zero-partition shuffle cannot carry any records, so always request at least
    // one partition. The upper clamp keeps the narrowing cast to int from overflowing.
    long requestedPartitions = mappedData.count() / batchSize;
    int numPartitions = (int) Math.min((long) Integer.MAX_VALUE, Math.max(1L, requestedPartitions));

    // Data sets that share a key are merged by stacking their feature and label rows.
    JavaPairRDD<Long, DataSet> aggregated = mappedData.reduceByKey(new Function2<DataSet, DataSet, DataSet>() {

        @Override
        public DataSet call(DataSet v1, DataSet v2) throws Exception {
            return new DataSet(Nd4j.vstack(v1.getFeatureMatrix(), v2.getFeatureMatrix()), Nd4j.vstack(v1.getLabels(), v2.getLabels()));
        }
    }, numPartitions);

    // Drop the index keys; each value is handed to flatMap as an Iterable<DataSet>.
    return aggregated.flatMap(new BaseFlatMapFunctionAdaptee<Tuple2<Long, DataSet>, DataSet>(new FlatMapFunctionAdapter<Tuple2<Long, DataSet>, DataSet>() {

        @Override
        public Iterable<DataSet> call(Tuple2<Long, DataSet> longDataSetTuple2) throws Exception {
            return longDataSetTuple2._2();
        }
    }));
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) FlatMapFunctionAdapter(org.datavec.spark.functions.FlatMapFunctionAdapter) LabeledPoint(org.apache.spark.mllib.regression.LabeledPoint) Tuple2(scala.Tuple2)

Aggregations

LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint): 1, FlatMapFunctionAdapter (org.datavec.spark.functions.FlatMapFunctionAdapter): 1, DataSet (org.nd4j.linalg.dataset.DataSet): 1, Tuple2 (scala.Tuple2): 1