Use of com.alibaba.alink.operator.common.distance.FastDistance in project Alink by Alibaba.
Class GeoKMeansTrainBatchOp, method linkFrom.
/**
 * Trains a k-means model on geographic points using the Haversine distance.
 *
 * <p>The latitude and longitude columns are packed, in that order, into a two-dimensional
 * {@code DenseVector}, so the vector size handed to centroid initialization and to the
 * iteration is the constant 2.
 *
 * @param inputs the training data; only the first operator is used
 * @return this operator, with its output set to the k-means model table
 * @throws IllegalArgumentException (when the job runs) if a latitude or longitude value is null
 */
@Override
public GeoKMeansTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    final String latitudeColName = this.getLatitudeCol();
    final String longitudeColName = this.getLongitudeCol();
    final FastDistance distance = new HaversineDistance();
    final int maxIter = this.getMaxIter();
    final double tol = this.getEpsilon();

    DataSet<FastDistanceVectorData> data = in
        .select(new String[] {latitudeColName, longitudeColName})
        .getDataSet()
        .rebalance()
        .map(new MapFunction<Row, FastDistanceVectorData>() {
            private static final long serialVersionUID = -5236022856006527961L;

            @Override
            public FastDistanceVectorData map(Row row) {
                Vector vec = new DenseVector(new double[] {
                    coordinate(row, 0, latitudeColName),
                    coordinate(row, 1, longitudeColName)
                });
                return distance.prepareVectorData(Row.of(vec), 0);
            }

            /** Reads a numeric coordinate, failing with a descriptive message on null. */
            private double coordinate(Row row, int index, String colName) {
                Object value = row.getField(index);
                if (value == null) {
                    throw new IllegalArgumentException(
                        "Null value found in column " + colName + "; coordinates must be non-null.");
                }
                return ((Number) value).doubleValue();
            }
        });

    // Every point is (latitude, longitude), so the vector size is fixed at 2.
    DataSet<Integer> vectorSize = MLEnvironmentFactory.get(getMLEnvironmentId())
        .getExecutionEnvironment()
        .fromElements(2);

    DataSet<FastDistanceMatrixData> initCentroid =
        initKmeansCentroids(data, distance, this.getParams(), vectorSize, getRandomSeed());

    // No vector column is passed (null); the latitude/longitude column names are forwarded instead.
    DataSet<Row> finalCentroid = iterateICQ(initCentroid, data, vectorSize, maxIter, tol, distance,
        HasKMeansWithHaversineDistanceType.DistanceType.HAVERSINE, null, latitudeColName, longitudeColName);

    // Store the clustering model to the output table.
    this.setOutput(finalCentroid, new KMeansModelDataConverter().getModelSchema());
    return this;
}
Use of com.alibaba.alink.operator.common.distance.FastDistance in project Alink by Alibaba.
Class KMeansTrainBatchOp, method linkFrom.
/**
 * Wires up k-means training for a vector column.
 *
 * <p>The vector column is summarized to obtain the vector dimension; an empty training set is
 * rejected at that point. Each vector is converted to {@code FastDistanceVectorData} for the
 * configured distance, and sparse vectors are resized to the summarized dimension first. Initial
 * centroids are then chosen, the clustering iteration is run, and the final centroids are
 * published as the model table.
 *
 * @param inputs training data; only the first operator is consumed
 * @return this operator with its output set to the k-means model
 */
@Override
public KMeansTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    final int maxIter = this.getMaxIter();
    final double tol = this.getEpsilon();
    final String vectorColName = this.getVectorCol();
    final DistanceType distanceType = getDistanceType();
    final FastDistance distance = distanceType.getFastDistance();

    Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> summary =
        StatisticsHelper.summaryHelper(in, null, vectorColName);
    DataSet<Vector> vectors = summary.f0;
    DataSet<BaseVectorSummary> vectorSummary = summary.f1;

    // Turn the summary into the vector dimension, refusing to train on an empty dataset.
    DataSet<Integer> vectorSize = vectorSummary.map(new MapFunction<BaseVectorSummary, Integer>() {
        private static final long serialVersionUID = 4184586558834055401L;

        @Override
        public Integer map(BaseVectorSummary value) {
            Preconditions.checkArgument(value.count() > 0, "The train dataset is empty!");
            return value.vectorSize();
        }
    });

    DataSet<FastDistanceVectorData> data = vectors
        .rebalance()
        .map(new RichMapFunction<Vector, FastDistanceVectorData>() {
            private static final long serialVersionUID = -7443226889326704768L;

            // Dimension delivered through the VECTOR_SIZE broadcast set.
            private int dim;

            @Override
            public void open(Configuration params) {
                dim = this.getRuntimeContext().<Integer>getBroadcastVariable(VECTOR_SIZE).get(0);
            }

            @Override
            public FastDistanceVectorData map(Vector value) {
                // Align sparse vectors with the dimension taken from the summary.
                if (value instanceof SparseVector) {
                    SparseVector sparse = (SparseVector) value;
                    sparse.setSize(dim);
                }
                return distance.prepareVectorData(Row.of(value), 0);
            }
        })
        .withBroadcastSet(vectorSize, VECTOR_SIZE);

    DataSet<FastDistanceMatrixData> initCentroid = KMeansInitCentroids.initKmeansCentroids(
        data, distance, this.getParams(), vectorSize, getRandomSeed());

    // The iteration takes its own distance enum; map by constant name.
    HasKMeansWithHaversineDistanceType.DistanceType iterationDistance =
        HasKMeansWithHaversineDistanceType.DistanceType.valueOf(distanceType.name());
    DataSet<Row> finalCentroid = iterateICQ(initCentroid, data, vectorSize, maxIter, tol, distance,
        iterationDistance, vectorColName, null, null);

    this.setOutput(finalCentroid, new KMeansModelDataConverter().getModelSchema());
    return this;
}
Use of com.alibaba.alink.operator.common.distance.FastDistance in project Alink by Alibaba.
Class VectorModelDataConverter, method buildIndex.
/**
 * Builds serialized index rows from the input for the configured fast distance.
 *
 * <p>Each input partition is converted by {@code FastDistance.prepareMatrixData} into a list of
 * {@code FastDistanceData} blocks. Every block becomes one row holding a type tag in
 * {@code FASTDISTANCE_TYPE_INDEX}:
 * <ul>
 *   <li>1 for {@code FastDistanceMatrixData}
 *   <li>2 for {@code FastDistanceVectorData}
 *   <li>3 for {@code FastDistanceSparseData}
 * </ul>
 * Each row also holds the block's {@code toString()} form in {@code DATA_INDEX}.
 *
 * <p>The rows are then passed to {@code VectorModelDataConverter.save}. Only subtask 0 supplies
 * {@code params} as metadata; all other subtasks pass {@code null}.
 *
 * @param in input operator whose rows are indexed
 * @param params parameters; {@code HasFastDistanceType.DISTANCE_TYPE} selects the distance
 * @return the rows emitted by {@code VectorModelDataConverter.save}
 * @throws RuntimeException (when the job runs) if prepareMatrixData yields an unsupported type
 */
@Override
public DataSet<Row> buildIndex(BatchOperator in, Params params) {
    DataSet<Row> dataSet = in.getDataSet();
    final FastDistance fastDistance = params.get(HasFastDistanceType.DISTANCE_TYPE).getFastDistance();

    DataSet<Row> index = dataSet.mapPartition(new RichMapPartitionFunction<Row, Row>() {
        private static final long serialVersionUID = -6035963841026118219L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            // NOTE(review): (1, 0) presumably select the vector and id columns — confirm.
            List<FastDistanceData> list = fastDistance.prepareMatrixData(values, 1, 0);
            for (FastDistanceData fastDistanceData : list) {
                final long typeIndex;
                if (fastDistanceData instanceof FastDistanceMatrixData) {
                    typeIndex = 1L;
                } else if (fastDistanceData instanceof FastDistanceVectorData) {
                    typeIndex = 2L;
                } else if (fastDistanceData instanceof FastDistanceSparseData) {
                    typeIndex = 3L;
                } else {
                    throw new RuntimeException(fastDistanceData.getClass().getName() + " is not supported!");
                }
                Row row = new Row(ROW_SIZE);
                row.setField(FASTDISTANCE_TYPE_INDEX, typeIndex);
                // toString() is virtual, so no downcast is needed to get the concrete form.
                row.setField(DATA_INDEX, fastDistanceData.toString());
                out.collect(row);
            }
        }
    });

    return index.mapPartition(new RichMapPartitionFunction<Row, Row>() {
        private static final long serialVersionUID = 661383020005730224L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            // Only subtask 0 passes the metadata, presumably so it is saved exactly once.
            Params meta = getRuntimeContext().getIndexOfThisSubtask() == 0 ? params : null;
            new VectorModelDataConverter().save(Tuple2.of(meta, values), out);
        }
    }).name("build_model");
}
Use of com.alibaba.alink.operator.common.distance.FastDistance in project Alink by Alibaba.
Class EvalClusterBatchOp, method linkFrom.
/**
 * Evaluates a clustering result and emits all metrics as a one-row table.
 *
 * <p>The output table has a single STRING column, named by the {@code EVAL_RESULT} constant. Its
 * value is the JSON form of one {@code Params} built by merging every metrics {@code Params}
 * computed below.
 *
 * <ul>
 *   <li>If a label column is set, label-based metrics come from a {@code LongMatrix}. That matrix
 *       is accumulated from (label, prediction) rows and handed to
 *       {@code ClusterEvaluationUtil.extractParamsFromConfusionMatrix}.
 *   <li>If a vector column is set, vector-based metrics are computed: per-cluster summaries plus
 *       a summed silhouette-coefficient value.
 *   <li>Otherwise only {@code BasicClusterParams} is applied to the prediction column.
 * </ul>
 *
 * @param inputs the data to evaluate; only the first operator is used
 * @return this operator, with its output table set
 */
@Override
public EvalClusterBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator in = checkAndGetFirst(inputs);
String labelColName = this.getLabelCol();
String predResultColName = this.getPredictionCol();
String vectorColName = this.getVectorCol();
DistanceType distanceType = getDistanceType();
FastDistance distance = distanceType.getFastDistance();
// An empty Params data set is the default for labelMetrics, so the final union always has
// a label-metrics input even when no label column is configured.
DataSet<Params> empty = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment().fromElements(new Params());
DataSet<Params> labelMetrics = empty, vectorMetrics;
if (null != labelColName) {
// Field 0: ground-truth label, field 1: predicted cluster.
DataSet<Row> data = in.select(new String[] { labelColName, predResultColName }).getDataSet();
// Index map over the distinct labels; rows rejected by checkRowFieldNotNull are skipped.
DataSet<Tuple1<Map<Object, Integer>>> labels = data.flatMap(new FlatMapFunction<Row, Object>() {
private static final long serialVersionUID = 6181506719667975996L;
@Override
public void flatMap(Row row, Collector<Object> collector) {
if (EvaluationUtil.checkRowFieldNotNull(row)) {
collector.collect(row.getField(0));
}
}
}).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, null, null, false)).project(0);
// Same as above, but over the distinct predicted cluster values (field 1).
DataSet<Tuple1<Map<Object, Integer>>> predictions = data.flatMap(new FlatMapFunction<Row, Object>() {
private static final long serialVersionUID = 619373417169823128L;
@Override
public void flatMap(Row row, Collector<Object> collector) {
if (EvaluationUtil.checkRowFieldNotNull(row)) {
collector.collect(row.getField(1));
}
}
}).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, null, null, false)).project(0);
// Each partition builds a local LongMatrix with CalLocalPredResult (given both index maps).
// The partial matrices are accumulated with plusEqual and converted to metric Params.
labelMetrics = data.rebalance().mapPartition(new CalLocalPredResult()).withBroadcastSet(labels, LABELS).withBroadcastSet(predictions, PREDICTIONS).reduce(new ReduceFunction<LongMatrix>() {
private static final long serialVersionUID = 3340266128816528106L;
@Override
public LongMatrix reduce(LongMatrix value1, LongMatrix value2) {
value1.plusEqual(value2);
return value1;
}
}).map(new RichMapFunction<LongMatrix, Params>() {
private static final long serialVersionUID = -4218363116865487327L;
@Override
public Params map(LongMatrix value) {
List<Tuple1<Map<Object, Integer>>> labels = getRuntimeContext().getBroadcastVariable(LABELS);
List<Tuple1<Map<Object, Integer>>> predictions = getRuntimeContext().getBroadcastVariable(PREDICTIONS);
return ClusterEvaluationUtil.extractParamsFromConfusionMatrix(value, labels.get(0).f0, predictions.get(0).f0);
}
}).withBroadcastSet(labels, LABELS).withBroadcastSet(predictions, PREDICTIONS);
}
if (null != vectorColName) {
// Parses the vector column and keeps the prediction column next to each vector.
// NOTE(review): inputs[0] is presumably the same operator as `in` (from checkAndGetFirst).
Tuple2<DataSet<Tuple2<Vector, Row>>, DataSet<BaseVectorSummary>> statistics = StatisticsHelper.summaryHelper(inputs[0], null, vectorColName, new String[] { predResultColName });
// (vector, cluster as String) pairs; FilterEmptyRow gets the vector summary as VECTOR_SIZE.
DataSet<Tuple2<Vector, String>> nonEmpty = statistics.f0.flatMap(new FilterEmptyRow()).withBroadcastSet(statistics.f1, VECTOR_SIZE);
// Per cluster (grouped on field 1): a (cluster, DenseVector, DenseVector) row from CalcMeanAndSum.
DataSet<Tuple3<String, DenseVector, DenseVector>> meanAndSum = nonEmpty.groupBy(1).reduceGroup(new CalcMeanAndSum(distance)).withBroadcastSet(statistics.f1, VECTOR_SIZE);
// Joins each cluster's points with its meanAndSum row (point field 1 == row field 0). All
// meanAndSum rows are also broadcast. The per-cluster summaries are reduced into one.
DataSet<BaseMetricsSummary> metricsSummary = nonEmpty.coGroup(meanAndSum).where(1).equalTo(0).with(new CalcClusterMetricsSummary(distance)).withBroadcastSet(meanAndSum, MEAN_AND_SUM).reduce(new EvaluationUtil.ReduceBaseMetrics());
// Per-point silhouette value computed against the reduced summary, then summed over all points.
DataSet<Tuple1<Double>> silhouetteCoefficient = nonEmpty.map(new RichMapFunction<Tuple2<Vector, String>, Tuple1<Double>>() {
private static final long serialVersionUID = 116926378586242272L;
@Override
public Tuple1<Double> map(Tuple2<Vector, String> value) {
List<BaseMetricsSummary> list = getRuntimeContext().getBroadcastVariable(METRICS_SUMMARY);
return ClusterEvaluationUtil.calSilhouetteCoefficient(value, (ClusterMetricsSummary) list.get(0));
}
}).withBroadcastSet(metricsSummary, METRICS_SUMMARY).aggregate(Aggregations.SUM, 0);
// SaveDataAsParams turns the summary into Params, receiving the summed silhouette via broadcast.
vectorMetrics = metricsSummary.map(new ClusterEvaluationUtil.SaveDataAsParams()).withBroadcastSet(silhouetteCoefficient, SILHOUETTE_COEFFICIENT);
} else {
// Without a vector column, only BasicClusterParams runs over the prediction column.
vectorMetrics = in.select(predResultColName).getDataSet().reduceGroup(new BasicClusterParams());
}
// Merge every metrics Params into one object and emit it as a single JSON string row.
DataSet<Row> out = labelMetrics.union(vectorMetrics).reduceGroup(new GroupReduceFunction<Params, Row>() {
private static final long serialVersionUID = -4726713311986089251L;
@Override
public void reduce(Iterable<Params> values, Collector<Row> out) {
Params params = new Params();
for (Params p : values) {
params.merge(p);
}
out.collect(Row.of(params.toJson()));
}
});
this.setOutputTable(DataSetConversionUtil.toTable(getMLEnvironmentId(), out, new TableSchema(new String[] { EVAL_RESULT }, new TypeInformation[] { Types.STRING })));
return this;
}
Aggregations