
Example 1 with FastDistanceSparseData

Use of com.alibaba.alink.operator.common.distance.FastDistanceSparseData in project Alink by alibaba.

Class VectorModelDataConverter, method buildIndex.

@Override
public DataSet<Row> buildIndex(BatchOperator in, Params params) {
    DataSet<Row> dataSet = in.getDataSet();
    FastDistance fastDistance = params.get(HasFastDistanceType.DISTANCE_TYPE).getFastDistance();
    DataSet<Row> index = dataSet.mapPartition(new RichMapPartitionFunction<Row, Row>() {

        private static final long serialVersionUID = -6035963841026118219L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            List<FastDistanceData> list = fastDistance.prepareMatrixData(values, 1, 0);
            for (FastDistanceData fastDistanceData : list) {
                Row row = new Row(ROW_SIZE);
                if (fastDistanceData instanceof FastDistanceMatrixData) {
                    // Type tag 1: dense matrix block
                    row.setField(FASTDISTANCE_TYPE_INDEX, 1L);
                    FastDistanceMatrixData data = (FastDistanceMatrixData) fastDistanceData;
                    row.setField(DATA_INDEX, data.toString());
                } else if (fastDistanceData instanceof FastDistanceVectorData) {
                    // Type tag 2: single dense vector
                    row.setField(FASTDISTANCE_TYPE_INDEX, 2L);
                    FastDistanceVectorData data = (FastDistanceVectorData) fastDistanceData;
                    row.setField(DATA_INDEX, data.toString());
                } else if (fastDistanceData instanceof FastDistanceSparseData) {
                    // Type tag 3: sparse data
                    row.setField(FASTDISTANCE_TYPE_INDEX, 3L);
                    FastDistanceSparseData data = (FastDistanceSparseData) fastDistanceData;
                    row.setField(DATA_INDEX, data.toString());
                } else {
                    throw new RuntimeException(fastDistanceData.getClass().getName() + " is not supported!");
                }
                out.collect(row);
            }
        }
    });
    return index.mapPartition(new RichMapPartitionFunction<Row, Row>() {

        private static final long serialVersionUID = 661383020005730224L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            Params meta = null;
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                meta = params;
            }
            new VectorModelDataConverter().save(Tuple2.of(meta, values), out);
        }
    }).name("build_model");
}
Also used: RichMapPartitionFunction (org.apache.flink.api.common.functions.RichMapPartitionFunction), Params (org.apache.flink.ml.api.misc.param.Params), FastDistanceVectorData (com.alibaba.alink.operator.common.distance.FastDistanceVectorData), FastDistanceSparseData (com.alibaba.alink.operator.common.distance.FastDistanceSparseData), FastDistanceMatrixData (com.alibaba.alink.operator.common.distance.FastDistanceMatrixData), FastDistance (com.alibaba.alink.operator.common.distance.FastDistance), Collector (org.apache.flink.util.Collector), FastDistanceData (com.alibaba.alink.operator.common.distance.FastDistanceData), ArrayList (java.util.ArrayList), List (java.util.List), Row (org.apache.flink.types.Row)
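In buildIndex, the first mapPartition tags each FastDistanceData block with a type code (1 = matrix, 2 = vector, 3 = sparse) and stores its toString() form; the second passes the Params only on subtask 0, presumably so the model metadata is written once. The plain-Java sketch below reproduces that pattern without Flink or Alink. The Block classes, typeTag, encodePartition, and metaFor are hypothetical names standing in for the FastDistance*Data classes and the Flink functions; this is an illustration, not Alink's implementation, and it throws IllegalArgumentException where the original throws a plain RuntimeException.

```java
import java.util.ArrayList;
import java.util.List;

public class TagEncodingSketch {

    // Hypothetical stand-ins for FastDistanceMatrixData / VectorData / SparseData.
    // Each carries a pre-serialized payload returned by toString(), mirroring
    // how the original stores data.toString() in the row.
    static abstract class Block {
        private final String payload;
        Block(String payload) { this.payload = payload; }
        @Override public String toString() { return payload; }
    }
    static final class DenseMatrixBlock extends Block { DenseMatrixBlock(String p) { super(p); } }
    static final class DenseVectorBlock extends Block { DenseVectorBlock(String p) { super(p); } }
    static final class SparseBlock extends Block { SparseBlock(String p) { super(p); } }
    static final class UnknownBlock extends Block { UnknownBlock(String p) { super(p); } }

    // Map a block to the same numeric type tags used in buildIndex.
    static long typeTag(Block b) {
        if (b instanceof DenseMatrixBlock) return 1L;
        if (b instanceof DenseVectorBlock) return 2L;
        if (b instanceof SparseBlock) return 3L;
        throw new IllegalArgumentException(b.getClass().getName() + " is not supported!");
    }

    // Encode one partition into (tag, payload) rows, like the first mapPartition.
    static List<Object[]> encodePartition(Iterable<? extends Block> blocks) {
        List<Object[]> rows = new ArrayList<>();
        for (Block b : blocks) {
            rows.add(new Object[] {typeTag(b), b.toString()});
        }
        return rows;
    }

    // Only subtask 0 carries the metadata; every other subtask passes null,
    // like the second mapPartition.
    static String metaFor(int subtaskIndex, String params) {
        return subtaskIndex == 0 ? params : null;
    }

    public static void main(String[] args) {
        List<Block> partition = List.of(new DenseMatrixBlock("m0"), new SparseBlock("s0"));
        for (Object[] row : encodePartition(partition)) {
            System.out.println(row[0] + "\t" + row[1]); // prints "1\tm0", then "3\ts0"
        }
        System.out.println(metaFor(0, "meta")); // prints "meta"
        System.out.println(metaFor(1, "meta")); // prints "null"
    }
}
```

Keeping the tag numeric and the payload a string keeps every row the same shape regardless of block type; a reader of the saved model can switch on the tag to decide how to parse the payload back.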
