Use of com.alibaba.alink.operator.common.distance.FastDistanceSparseData in the Alink project by Alibaba.
Class VectorModelDataConverter, method buildIndex.
/**
 * Builds the index by serializing the fast-distance data prepared from the input vectors.
 *
 * <p>Stage 1: every partition of {@code in} is converted by
 * {@code FastDistance#prepareMatrixData(values, 1, 0)} into a list of {@code FastDistanceData}
 * blocks. Each block becomes one row of width {@code ROW_SIZE}:
 * <ul>
 *   <li>{@code FASTDISTANCE_TYPE_INDEX} holds a type code: 1 for {@code FastDistanceMatrixData},
 *       2 for {@code FastDistanceVectorData}, 3 for {@code FastDistanceSparseData};</li>
 *   <li>{@code DATA_INDEX} holds the block's {@code toString()} form.</li>
 * </ul>
 *
 * <p>Stage 2: each partition of those rows is passed to {@code VectorModelDataConverter#save}.
 * Only subtask 0 passes {@code params} as the metadata; all other subtasks pass {@code null}.
 *
 * @param in     batch operator whose data set contains the rows to index
 * @param params parameters; {@code HasFastDistanceType.DISTANCE_TYPE} selects the
 *               {@code FastDistance} implementation, and the whole object is saved as metadata
 * @return the data set produced by {@code save}, with the operator named {@code "build_model"}
 * @throws RuntimeException at execution time if {@code prepareMatrixData} yields a
 *         {@code FastDistanceData} of any other type, or a {@code null} element
 */
@Override
public DataSet<Row> buildIndex(BatchOperator in, Params params) {
    DataSet<Row> dataSet = in.getDataSet();
    FastDistance fastDistance = params.get(HasFastDistanceType.DISTANCE_TYPE).getFastDistance();
    DataSet<Row> index = dataSet.mapPartition(new RichMapPartitionFunction<Row, Row>() {
        private static final long serialVersionUID = -6035963841026118219L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            // NOTE(review): the meaning of the (1, 0) arguments is defined by FastDistance#prepareMatrixData;
            // confirm there before changing them.
            List<FastDistanceData> list = fastDistance.prepareMatrixData(values, 1, 0);
            for (FastDistanceData fastDistanceData : list) {
                Row row = new Row(ROW_SIZE);
                row.setField(FASTDISTANCE_TYPE_INDEX, typeCodeOf(fastDistanceData));
                // toString() is virtual, so the concrete subtype's serialization is used without a downcast.
                row.setField(DATA_INDEX, fastDistanceData.toString());
                out.collect(row);
            }
        }

        /** Returns the type code stored next to the serialized form, or throws for unsupported data. */
        private long typeCodeOf(FastDistanceData data) {
            // Checked in the original order in case these types are related by inheritance.
            if (data instanceof FastDistanceMatrixData) {
                return 1L;
            } else if (data instanceof FastDistanceVectorData) {
                return 2L;
            } else if (data instanceof FastDistanceSparseData) {
                return 3L;
            }
            String typeName = (data == null) ? "null" : data.getClass().getName();
            throw new RuntimeException(typeName + " is not supported!");
        }
    });
    return index.mapPartition(new RichMapPartitionFunction<Row, Row>() {
        private static final long serialVersionUID = 661383020005730224L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            // Only subtask 0 forwards the params as metadata, presumably so they are recorded once
            // rather than once per parallel instance.
            Params meta = getRuntimeContext().getIndexOfThisSubtask() == 0 ? params : null;
            new VectorModelDataConverter().save(Tuple2.of(meta, values), out);
        }
    }).name("build_model");
}
Aggregations